diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln index 18d9d3867e..db998668dd 100644 --- a/Microsoft.ML.sln +++ b/Microsoft.ML.sln @@ -103,13 +103,14 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CpuMath.UnitTe EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CpuMath.UnitTests.netcoreapp", "test\Microsoft.ML.CpuMath.UnitTests.netcoreapp\Microsoft.ML.CpuMath.UnitTests.netcoreapp.csproj", "{5F81A2A4-73AD-494C-B387-07D605EC8826}" EndProject - -Project("{F2A71F9B-5D33-465A-A702-920D77279786}") = "Microsoft.ML.FSharp.Tests", "test\Microsoft.ML.FSharp.Tests\Microsoft.ML.FSharp.Tests.fsproj", "{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}" +Project("{6EC3EE1D-3C4E-46DD-8F32-0CC8E7565705}") = "Microsoft.ML.FSharp.Tests", "test\Microsoft.ML.FSharp.Tests\Microsoft.ML.FSharp.Tests.fsproj", "{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.ImageAnalytics", "src\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj", "{00E38F77-1E61-4CDF-8F97-1417D4E85053}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.HalLearners", "src\Microsoft.ML.HalLearners\Microsoft.ML.HalLearners.csproj", "{A7222F41-1CF0-47D9-B80C-B4D77B027A61}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.TensorFlow", "src\Microsoft.ML.TensorFlow\Microsoft.ML.TensorFlow.csproj", "{570A0B8A-5463-44D2-8521-54C0CA4CACA9}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -390,6 +391,14 @@ Global {A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Release|Any CPU.Build.0 = Release|Any CPU {A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU {A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU + {570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Debug|Any CPU.Build.0 = Debug|Any CPU + {570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU + {570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU + {570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Release|Any CPU.ActiveCfg = Release|Any CPU + {570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Release|Any CPU.Build.0 = Release|Any CPU + {570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU + {570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -426,14 +435,15 @@ Global {001F3B4E-FBE4-4001-AFD2-A6A989CD1C25} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {DCF46B79-1FDB-4DBA-A263-D3D64E3AAA27} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {BF66A305-DF10-47E4-8D81-42049B149D2B} = {D3D38B03-B557-484D-8348-8BADEE4DF592} + {B4E55B2D-2A92-46E7-B72F-E76D6FD83440} = {7F13E156-3EBA-4021-84A5-CD56BA72F99E} + {3E4ABF07-7970-4BE6-B45B-A13D3C397545} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} {7333EDEF-4144-405C-A5EC-6F42201857D8} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} {A0E562A9-0E6D-470D-B180-6EB44BA84D60} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} {5F81A2A4-73AD-494C-B387-07D605EC8826} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} - {B4E55B2D-2A92-46E7-B72F-E76D6FD83440} = {7F13E156-3EBA-4021-84A5-CD56BA72F99E} - {3E4ABF07-7970-4BE6-B45B-A13D3C397545} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} {802233D6-8CC0-46AD-9F23-FEE1E9AED9B3} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} {00E38F77-1E61-4CDF-8F97-1417D4E85053} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {A7222F41-1CF0-47D9-B80C-B4D77B027A61} = {09EADF06-BE25-4228-AB53-95AE3E15B530} + {570A0B8A-5463-44D2-8521-54C0CA4CACA9} = {09EADF06-BE25-4228-AB53-95AE3E15B530} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D} diff --git a/build.proj b/build.proj index c9be14e930..ef922a7148 100644 --- a/build.proj +++ b/build.proj @@ -31,6 +31,7 @@ CreateOrUpdateCurrentVersionFile; RestoreProjects; + BuildRedist; BuildNative; $(TraversalBuildDependsOn); DownloadExternalTestFiles; @@ -44,9 +45,17 @@ Properties="MSBuildWarningsAsMessages=NU1503" /> + + + + + + DependsOnTargets="RestoreProjects;BuildRedist"> diff --git a/build/Dependencies.props b/build/Dependencies.props index 3a46917114..0b6af3cdc9 100644 --- a/build/Dependencies.props +++ b/build/Dependencies.props @@ -11,5 +11,6 @@ 0.0.0.5 4.5.0 0.11.0 + 1.10.0 diff --git a/build/sign.proj b/build/sign.proj index 8f8523474c..498ed7b433 100644 --- a/build/sign.proj +++ b/build/sign.proj @@ -30,7 +30,10 @@ - + + + + Microsoft diff --git a/pkg/Microsoft.ML.TensorFlow.Redist/Microsoft.ML.TensorFlow.Redist.nupkgproj b/pkg/Microsoft.ML.TensorFlow.Redist/Microsoft.ML.TensorFlow.Redist.nupkgproj new file mode 100644 index 0000000000..0de839bf08 --- /dev/null +++ b/pkg/Microsoft.ML.TensorFlow.Redist/Microsoft.ML.TensorFlow.Redist.nupkgproj @@ -0,0 +1,21 @@ + + + + The TensorFlow Authors + netstandard2.0 + $(MSBuildProjectName) contains the TensorFlow C library version $(TensorFlowVersion) redistributed as a NuGet package. + https://github.com/tensorflow/tensorflow/blob/master/LICENSE + true + Copyright 2018 The TensorFlow Authors. All rights reserved. + https://www.tensorflow.org + https://github.com/tensorflow/tensorflow/releases/tag/v$(TensorFlowVersion) + $(PackageTags) TensorFlow + + + + + + + + + diff --git a/pkg/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.nupkgproj b/pkg/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.nupkgproj new file mode 100644 index 0000000000..fb07384f90 --- /dev/null +++ b/pkg/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.nupkgproj @@ -0,0 +1,12 @@ + + + + netstandard2.0 + Microsoft.ML.TensorFlow contains ML.NET integration of TensorFlow. + + + + + + + diff --git a/pkg/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.symbols.nupkgproj b/pkg/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.symbols.nupkgproj new file mode 100644 index 0000000000..a2a2a153f7 --- /dev/null +++ b/pkg/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.symbols.nupkgproj @@ -0,0 +1,5 @@ + + + + + diff --git a/src/Directory.Build.props b/src/Directory.Build.props index 113da3575a..ee32523d8e 100644 --- a/src/Directory.Build.props +++ b/src/Directory.Build.props @@ -12,6 +12,15 @@ $(WarningsNotAsErrors);1591 $(MSBuildThisFileDirectory)\Source.ruleset + + x64 + + $(BaseOutputPath)$(TargetArchitecture).$(Configuration)\Native + + win + linux + osx + $(PackageRid)-$(TargetArchitecture) diff --git a/src/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.csproj b/src/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.csproj new file mode 100644 index 0000000000..c04bdde2da --- /dev/null +++ b/src/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.csproj @@ -0,0 +1,31 @@ + + + + netstandard2.0 + Microsoft.ML.TensorFlow + CORECLR + true + + + + + + + + + + True + True + TensorGeneric.tt + + + TextTemplatingFileGenerator + TensorGeneric.cs + + + + + + + + diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/Buffer.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/Buffer.cs new file mode 100644 index 0000000000..28d7879ba5 --- /dev/null +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/Buffer.cs @@ -0,0 +1,211 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Runtime.InteropServices; +using System.Text; +using size_t = System.UIntPtr; + +#pragma warning disable MSML_GeneralName +#pragma warning disable MSML_ParameterLocalVarName + +namespace Microsoft.ML.Transforms.TensorFlow +{ + /// + /// This attribute can be applied to callback functions that will be invoked + /// from unmanaged code to managed code. + /// + /// + /// + /// [TensorFlow.MonoPInvokeCallback (typeof (BufferReleaseFunc))] + /// internal static void MyFreeFunc (IntPtr data, IntPtr length){..} + /// + /// + internal sealed class MonoPInvokeCallbackAttribute : Attribute + { + /// + /// Use this constructor to annotate the type of the callback function that + /// will be invoked from unmanaged code. + /// + /// T. + public MonoPInvokeCallbackAttribute(Type t) { } + } + + [StructLayout(LayoutKind.Sequential)] + internal struct LLBuffer + { + internal IntPtr data; + internal size_t length; + internal IntPtr data_deallocator; + } + + /// + /// Holds a block of data, suitable to pass, or retrieve from TensorFlow. + /// + /// + /// + /// Use the TFBuffer to blobs of data into TensorFlow, or to retrieve blocks + /// of data out of TensorFlow. + /// + /// + /// There are two constructors to wrap existing data, one to wrap blocks that are + /// pointed to by an IntPtr and one that takes a byte array that we want to wrap. + /// + /// + /// The empty constructor can be used to create a new TFBuffer that can be populated + /// by the TensorFlow library and returned to user code. + /// + /// + /// Typically, the data consists of a serialized protocol buffer, but other data + /// may also be held in a buffer. + /// + /// + // TODO: the string ctor + // TODO: perhaps we should have an implicit byte [] conversion that just calls ToArray? + internal class TFBuffer : TFDisposable + { + // extern TF_Buffer * TF_NewBufferFromString (const void *proto, size_t proto_len); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe LLBuffer* TF_NewBufferFromString(IntPtr proto, IntPtr proto_len); + + // extern TF_Buffer * TF_NewBuffer (); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe LLBuffer* TF_NewBuffer(); + + internal TFBuffer(IntPtr handle) : base(handle) { } + + /// + /// Initializes a new instance of the class. + /// + public unsafe TFBuffer() : base((IntPtr)TF_NewBuffer()) + { + } + + /// + /// Signature of the method that is invoked to release the data. + /// + /// + /// Methods of this signature are invoked with the data pointer and the + /// lenght pointer when then TFBuffer no longer needs to hold on to the + /// data. If you are using this on platforms with static compilation + /// like iOS, you need to annotate your callback with the MonoPInvokeCallbackAttribute, + /// like this: + /// + /// + /// [TensorFlow.MonoPInvokeCallback (typeof (BufferReleaseFunc))] + /// internal static void MyFreeFunc (IntPtr data, IntPtr length){..} + /// + /// + public delegate void BufferReleaseFunc(IntPtr data, IntPtr lenght); + + /// + /// Initializes a new instance of the by wrapping the unmanaged resource pointed by the buffer. + /// + /// Pointer to the data that will be wrapped. + /// The size of the buffer to wrap. + /// Optional, if not null, this method will be invoked to release the block. + /// + /// This constructor wraps the buffer as a the data to be held by the , + /// if the release parameter is null, then you must ensure that the data is not released before the TFBuffer + /// is no longer in use. If the value is not null, the provided method will be invoked to release + /// the data when the TFBuffer is disposed, or the contents of the buffer replaced. + /// + public unsafe TFBuffer(IntPtr buffer, long size, BufferReleaseFunc release) : base((IntPtr)TF_NewBuffer()) + { + LLBuffer* buf = (LLBuffer*)handle; + buf->data = buffer; + buf->length = (size_t)size; + if (release == null) + buf->data_deallocator = IntPtr.Zero; + else + buf->data_deallocator = Marshal.GetFunctionPointerForDelegate(release); + } + + [MonoPInvokeCallback(typeof(BufferReleaseFunc))] + internal static void FreeBlock(IntPtr data, IntPtr length) + { + Marshal.FreeHGlobal(data); + } + + internal static IntPtr FreeBufferFunc; + internal static BufferReleaseFunc FreeBlockDelegate; + + static TFBuffer() + { + FreeBlockDelegate = FreeBlock; + FreeBufferFunc = Marshal.GetFunctionPointerForDelegate(FreeBlockDelegate); + } + + /// + /// Initializes a new instance of the by making a copy of the provided byte array. + /// + /// Buffer of data that will be wrapped. + /// + /// This constructor makes a copy of the data into an unmanaged buffer, + /// so the byte array is not pinned. + /// + public TFBuffer(byte[] buffer) : this(buffer, 0, buffer.Length) { } + + /// + /// Initializes a new instance of the by making a copy of the provided byte array. + /// + /// Buffer of data that will be wrapped. + /// Starting offset into the buffer to wrap. + /// Number of bytes from the buffer to keep. + /// + /// This constructor makes a copy of the data into an unmanaged buffer, + /// so the byte array is not pinned. + /// + public TFBuffer(byte[] buffer, int start, int count) : this() + { + if (start < 0 || start >= buffer.Length) + throw new ArgumentException("start"); + if (count < 0 || count > buffer.Length - start) + throw new ArgumentException("count"); + unsafe + { + LLBuffer* buf = LLBuffer; + buf->data = Marshal.AllocHGlobal(count); + Marshal.Copy(buffer, start, buf->data, count); + buf->length = (size_t)count; + buf->data_deallocator = FreeBufferFunc; + } + } + + internal unsafe LLBuffer* LLBuffer => (LLBuffer*)handle; + + // extern void TF_DeleteBuffer (TF_Buffer *); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe void TF_DeleteBuffer(LLBuffer* buffer); + + internal override void NativeDispose(IntPtr handle) + { + unsafe { TF_DeleteBuffer((LLBuffer*)handle); } + } + + // extern TF_Buffer TF_GetBuffer (TF_Buffer *buffer); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe LLBuffer TF_GetBuffer(LLBuffer* buffer); + + /// + /// Returns a byte array representing the data wrapped by this buffer. + /// + /// The array. + public byte[] ToArray() + { + if (handle == IntPtr.Zero) + return null; + + unsafe + { + var lb = (LLBuffer*)handle; + + var result = new byte[(int)lb->length]; + Marshal.Copy(lb->data, result, 0, (int)lb->length); + + return result; + } + } + } +} diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/Tensor.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/Tensor.cs new file mode 100644 index 0000000000..02778373e4 --- /dev/null +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/Tensor.cs @@ -0,0 +1,979 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; +using System.Numerics; +using System.Runtime.InteropServices; +using System.Text; +using size_t = System.UIntPtr; +using TF_Tensor = System.IntPtr; + +#pragma warning disable MSML_ParameterLocalVarName + +namespace Microsoft.ML.Transforms.TensorFlow +{ + + /// + /// TFTensor holds a multi-dimensional array of elements of a single data type. + /// + /// + /// + /// You can create tensors with the various constructors in this class, or using + /// the implicit conversions from various data types into a TFTensor, including + /// the creation of tensors from simple constants (returning a tensor that reprensets + /// a scalar, that is, it is a 0D tensor), arrays (returning a tensor of a single + /// dimension, 1D) or arbitrary multidimensional arrays. + /// + /// + /// Given a tensor, you can retrieve the number of dimensions in it via the + /// NumDims property, or you can retrieve the shape of a tensor, that is how many + /// elements on each dimension the tensor has, by fetching the Shape property. + /// + /// + /// The implicit conversions for basic types produce tensors of one dimesion with + /// a single element, while the implicit conversion from an array, expects a multi-dimensional + /// array that is converted into a tensor of the right dimensions. + /// + /// + /// The special "String" tensor data type that you will find in TensorFlow documentation + /// really represents a byte array. You can create string tensors by using the + /// method that takes a byte array buffer as input. + /// + /// + /// + /// TFTensor scalar = 1; // Creates a 0D tensor, for the integer value 1 + /// int d = scalar.NumDims; // d will be equal to zero, as it is a 0D tensor + /// long [] shape = scalar.Shape // returns an empty array, as it is a 0D tensor + /// + /// TFTensor list = new [] {1,2,3} // Creates a 1D tensor, or vector, for the values 1, 2, 3 + /// d = list.NumDims; // d will be one + /// shape = list.Shape; // shape will be an array with a single value 3, representing that the dimension 0 has 3 elements + /// + /// // Creates a 3D tensor, + /// TFTensor cube = new [,,] { {{1,2,3},{4,5,6}}} + /// d = cube.NumDims // d will be 3 + /// shape = list.Shape // shape will be [1,2,3] which is the shape of the above 3D array + /// + /// + /// + internal partial class TFTensor : TFDisposableThreadSafe + { + /// + /// Signature that methods must conform to to be used to release memory that was passed to a manually allocated TFTensor + /// + public delegate void Deallocator(IntPtr data, IntPtr size, IntPtr deallocatorData); + + // extern TF_Tensor * TF_NewTensor (TF_DataType, const int64_t *dims, int num_dims, void *data, size_t len, void (* deallocator)(void *, size_t, void *), void *deallocator_arg); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe TF_Tensor TF_NewTensor(TFDataType dataType, long[] dims, int num_dims, IntPtr data, size_t len, Deallocator deallocator, IntPtr deallocator_arg); + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe TF_Tensor TF_NewTensor(TFDataType dataType, IntPtr zeroDims, int num_dims, IntPtr data, size_t len, Deallocator deallocator, IntPtr deallocator_arg); + + internal TFTensor(IntPtr handle) : base(handle) { } + + internal static Deallocator FreeTensorDataDelegate = FreeTensorData; + internal static Deallocator FreeTensorHandleDelegate = FreeTensorHandle; + + [MonoPInvokeCallback(typeof(Deallocator))] + internal static void FreeTensorData(IntPtr data, IntPtr len, IntPtr closure) + { + Marshal.FreeHGlobal(data); + } + + [MonoPInvokeCallback(typeof(Deallocator))] + internal static void FreeTensorHandle(IntPtr data, IntPtr len, IntPtr closure) + { + var gch = GCHandle.FromIntPtr(closure); + gch.Free(); + } + + // TODO: Other overloads we could add: String, Complex (float), Bool, QInt8, QUInt8, QInt32, Bfloat16, + // QInt16, QUint16, Half, Resource + // TODO: not clear that this is very useful (the dims versions), perhaps to reduce the surface of + // construcors these rarer blobs should be "FromSpec" or something like that + + /// + /// Creates a new tensor from a portion of an array of sbytes + /// + /// Represents the tensor shape. + /// The linear array of data, the data is shuffled to fit in the tensor with the specified dimensions. + /// The offset into the provided data array where the data resides. + /// The number of bytes to copy from count into the tensor. + /// + /// Use the FromBuffer method to create a tensor that has the specified dimensions + /// and is initialized with data from the data array. The data is copied starting + /// at the start offset, for count bytes and is laid out into the tensor following the + /// specified dimensions. + /// + public static TFTensor FromBuffer(TFShape shape, sbyte[] data, int start, int count) + { + return new TFTensor(SetupTensor(TFDataType.Int8, shape, data, start, count, size: 2)); + } + + /// + /// Creates a new tensor from a portion of an array of bytes + /// + /// Represents the tensor shape. + /// The linear array of data, the data is shuffled to fit in the tensor with the specified dimensions. + /// The offset into the provided data array where the data resides. + /// The number of bytes to copy from count into the tensor. + /// + /// Use the FromBuffer method to create a tensor that has the specified dimensions + /// and is initialized with data from the data array. The data is copied starting + /// at the start offset, for count bytes and is laid out into the tensor following the + /// specified dimensions. + /// + public static TFTensor FromBuffer(TFShape shape, byte[] data, int start, int count) + { + return new TFTensor(SetupTensor(TFDataType.UInt8, shape, data, start, count, size: 1)); + } + + /// + /// Creates a new tensor from a portion of an array of shorts + /// + /// Represents the tensor shape. + /// The linear array of data, the data is shuffled to fit in the tensor with the specified dimensions. + /// The offset into the provided data array where the data resides. + /// The number of bytes to copy from count into the tensor. + /// + /// Use the FromBuffer method to create a tensor that has the specified dimensions + /// and is initialized with data from the data array. The data is copied starting + /// at the start offset, for count bytes and is laid out into the tensor following the + /// specified dimensions. + /// + public static TFTensor FromBuffer(TFShape shape, short[] data, int start, int count) + { + return new TFTensor(SetupTensor(TFDataType.Int16, shape, data, start, count, size: 2)); + } + + /// + /// Creates a new tensor from a portion of an array of ushorts + /// + /// Represents the tensor shape. + /// The linear array of data, the data is shuffled to fit in the tensor with the specified dimensions. + /// The offset into the provided data array where the data resides. + /// The number of bytes to copy from count into the tensor. + /// + /// Use the FromBuffer method to create a tensor that has the specified dimensions + /// and is initialized with data from the data array. The data is copied starting + /// at the start offset, for count bytes and is laid out into the tensor following the + /// specified dimensions. + /// + public static TFTensor FromBuffer(TFShape shape, ushort[] data, int start, int count) + { + return new TFTensor(SetupTensor(TFDataType.UInt16, shape, data, start, count, size: 2)); + } + + /// + /// Creates a new tensor from a portion of an array of ints + /// + /// Represents the tensor shape. + /// The linear array of data, the data is shuffled to fit in the tensor with the specified dimensions. + /// The offset into the provided data array where the data resides. + /// The number of bytes to copy from count into the tensor. + /// + /// Use the FromBuffer method to create a tensor that has the specified dimensions + /// and is initialized with data from the data array. The data is copied starting + /// at the start offset, for count bytes and is laid out into the tensor following the + /// specified dimensions. + /// + public static TFTensor FromBuffer(TFShape shape, int[] data, int start, int count) + { + return new TFTensor(SetupTensor(TFDataType.Int32, shape, data, start, count, size: 4)); + } + + /// + /// Creates a new tensor from a portion of an array of floats + /// + /// Represents the tensor shape. + /// The linear array of data, the data is shuffled to fit in the tensor with the specified dimensions. + /// The offset into the provided data array where the data resides. + /// The number of bytes to copy from count into the tensor. + /// + /// Use the FromBuffer method to create a tensor that has the specified dimensions + /// and is initialized with data from the data array. The data is copied starting + /// at the start offset, for count bytes and is laid out into the tensor following the + /// specified dimensions. + /// + public static TFTensor FromBuffer(TFShape shape, float[] data, int start, int count) + { + return new TFTensor(SetupTensor(TFDataType.Float, shape, data, start, count, size: 4)); + } + + /// + /// Creates a new tensor from a portion of an array of doubles + /// + /// Represents the tensor shape. + /// The linear array of data, the data is shuffled to fit in the tensor with the specified dimensions. + /// The offset into the provided data array where the data resides. + /// The number of bytes to copy from count into the tensor. + /// + /// Use the FromBuffer method to create a tensor that has the specified dimensions + /// and is initialized with data from the data array. The data is copied starting + /// at the start offset, for count bytes and is laid out into the tensor following the + /// specified dimensions. + /// + public static TFTensor FromBuffer(TFShape shape, double[] data, int start, int count) + { + return new TFTensor(SetupTensor(TFDataType.Double, shape, data, start, count, size: 8)); + } + + /// + /// Creates a new tensor from a portion of an array of longs + /// + /// Represents the tensor shape. + /// The linear array of data, the data is shuffled to fit in the tensor with the specified dimensions. + /// The offset into the provided data array where the data resides. + /// The number of bytes to copy from count into the tensor. + /// + /// Use the FromBuffer method to create a tensor that has the specified dimensions + /// and is initialized with data from the data array. The data is copied starting + /// at the start offset, for count bytes and is laid out into the tensor following the + /// specified dimensions. + /// + public static TFTensor FromBuffer(TFShape shape, long[] data, int start, int count) + { + return new TFTensor(SetupTensor(TFDataType.Int64, shape, data, start, count, size: 8)); + } + + /// + /// Creates a new tensor from a portion of an array of Complex numbers + /// + /// Represents the tensor shape. + /// The linear array of data, the data is shuffled to fit in the tensor with the specified dimensions. + /// The offset into the provided data array where the data resides. + /// The number of bytes to copy from count into the tensor. + /// + /// Use the FromBuffer method to create a tensor that has the specified dimensions + /// and is initialized with data from the data array. The data is copied starting + /// at the start offset, for count bytes and is laid out into the tensor following the + /// specified dimensions. + /// + public static TFTensor FromBuffer(TFShape shape, Complex[] data, int start, int count) + { + return new TFTensor(SetupTensor(TFDataType.Complex128, shape, data, start, count, size: 16)); + } + + /// + /// Creates a constant tensor with a single dimension from an integer value. + /// + public unsafe TFTensor(int value) + { + var v = (int*)Marshal.AllocHGlobal(sizeof(int)); + *v = value; + handle = TF_NewTensor(TFDataType.Int32, zeroDims: IntPtr.Zero, num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(int), deallocator: FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero); + } + + /// + /// Creates a constant tensor with a single dimension from a boolean value. + /// + public unsafe TFTensor(bool value) + { + var v = (bool*)Marshal.AllocHGlobal(sizeof(bool)); + *v = value; + handle = TF_NewTensor(TFDataType.Bool, zeroDims: IntPtr.Zero, num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(int), deallocator: FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero); + } + + /// + /// Creates a constant tensor with a single dimension from an sbyte value. + /// + public unsafe TFTensor(sbyte value) + { + var v = (sbyte*)Marshal.AllocHGlobal(sizeof(sbyte)); + *v = value; + handle = TF_NewTensor(TFDataType.Int8, zeroDims: IntPtr.Zero, num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(sbyte), deallocator: FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero); + } + + /// + /// Creates a constant tensor with a single dimension from a short value. + /// + public unsafe TFTensor(short value) + { + var v = (short*)Marshal.AllocHGlobal(sizeof(short)); + *v = value; + handle = TF_NewTensor(TFDataType.Int16, zeroDims: IntPtr.Zero, num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(short), deallocator: FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero); + } + + /// + /// Creates a constant tensor with a single dimension from an ushort value. + /// + public unsafe TFTensor(ushort value) + { + var v = (ushort*)Marshal.AllocHGlobal(sizeof(ushort)); + *v = value; + handle = TF_NewTensor(TFDataType.Int16, zeroDims: IntPtr.Zero, num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(ushort), deallocator: FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero); + } + + /// + /// Creates a constant tensor with a single dimension from an byte value. + /// + public unsafe TFTensor(byte value) + { + var v = (int*)Marshal.AllocHGlobal(sizeof(byte)); + *v = value; + handle = TF_NewTensor(TFDataType.UInt8, zeroDims: IntPtr.Zero, num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(byte), deallocator: FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero); + } + + /// + /// Creates a constant tensor with a single dimension from a Complex value. + /// + public unsafe TFTensor(Complex value) + { + var v = (Complex*)Marshal.AllocHGlobal(sizeof(Complex)); + *v = value; + handle = TF_NewTensor(TFDataType.Complex128, zeroDims: IntPtr.Zero, num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(Complex), deallocator: FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero); + } + + /// + /// Creates a constant tensor with a single dimension from a float value. + /// + public unsafe TFTensor(float value) + { + var v = (float*)Marshal.AllocHGlobal(sizeof(float)); + *v = value; + handle = TF_NewTensor(TFDataType.Float, zeroDims: IntPtr.Zero, num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(float), deallocator: FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero); + } + + /// + /// Creates a constant tensor with a single dimension from a double value. + /// + public unsafe TFTensor(double value) + { + var v = (double*)Marshal.AllocHGlobal(sizeof(double)); + *v = value; + handle = TF_NewTensor(TFDataType.Double, zeroDims: IntPtr.Zero, num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(double), deallocator: FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero); + } + + /// + /// Creates a constant tensor with a single dimension from a long value. + /// + public unsafe TFTensor(long value) + { + var v = (long*)Marshal.AllocHGlobal(sizeof(long)); + *v = value; + handle = TF_NewTensor(TFDataType.Int64, zeroDims: IntPtr.Zero, num_dims: 0, data: (IntPtr)v, len: (UIntPtr)sizeof(long), deallocator: FreeTensorDataDelegate, deallocator_arg: IntPtr.Zero); + } + /// + /// Creates a 1 dimensional tensor from an array of booleans. + /// + /// Data. + public TFTensor(bool[] data) : base(SetupTensor(TFDataType.Bool, data, size: 1)) { } + /// + /// Creates a 1 dimensional tensor from an array of sbytes. + /// + /// Data. + public TFTensor(sbyte[] data) : base(SetupTensor(TFDataType.Int8, data, size: 1)) { } + /// + /// Creates a 1 dimensional tensor from an array of bytes. + /// + /// Data. + public TFTensor(byte[] data) : base(SetupTensor(TFDataType.UInt8, data, size: 1)) { } + /// + /// Creates a 1 dimensional tensor from an array of shorts. + /// + /// Data. + public TFTensor(short[] data) : base(SetupTensor(TFDataType.Int16, data, size: 2)) { } + /// + /// Creates a 1 dimensional tensor from an array of ushorts + /// + /// Data. + public TFTensor(ushort[] data) : base(SetupTensor(TFDataType.UInt16, data, size: 2)) { } + /// + /// Creates a 1 dimensional tensor from an array of ints. + /// + /// Data. + public TFTensor(int[] data) : base(SetupTensor(TFDataType.Int32, data, size: 4)) { } + /// + /// Creates a 1 dimensional tensor from an array of floats. + /// + /// Data. + public TFTensor(float[] data) : base(SetupTensor(TFDataType.Float, data, size: 4)) { } + /// + /// Creates a 1 dimensional tensor from an array of doubles. + /// + /// Data. + public TFTensor(double[] data) : base(SetupTensor(TFDataType.Double, data, size: 8)) { } + /// + /// Creates a 1 dimensional tensor from an array of longs. + /// + /// Data. + public TFTensor(long[] data) : base(SetupTensor(TFDataType.Int64, data, size: 8)) { } + /// + /// Creates a 1 dimensional tensor from an array of complex numbers. + /// + /// Data. + public TFTensor(Complex[] data) : base(SetupTensor(TFDataType.Complex128, data, size: 16)) { } + + // Convenience function to factor out the setup of a new tensor from an array + internal static IntPtr SetupTensor(TFDataType dt, long[] dims, Array data, int size) + { + return SetupTensor(dt, dims, data, start: 0, count: data.Length, size: size); + } + + // Convenience function to factor out the setup of a new tensor from an array + internal static IntPtr SetupTensor(TFDataType dt, Array data, int size) + { + long[] dims = new long[data.Rank]; + for (int i = 0; i < dims.Length; i++) + dims[i] = data.GetLength(i); + + return SetupTensor(dt, dims, data, start: 0, count: data.Length, size: size); + } + + // Use for single dimension arrays + internal static IntPtr SetupTensor(TFDataType dt, TFShape shape, Array data, int start, int count, int size) + { + if (shape == null) + throw new ArgumentNullException(nameof(shape)); + return SetupTensor(dt, shape.dims, data, start, count, size); + } + + // Use for single dimension arrays + internal static IntPtr SetupTensor(TFDataType dt, long[] dims, Array data, int start, int count, int size) + { + if (start < 0 || start > data.Length - count) + throw new ArgumentException("start + count > Array size"); + + var dataHandle = GCHandle.Alloc(data, GCHandleType.Pinned); + + if (dims == null) + return TF_NewTensor(dt, IntPtr.Zero, 0, dataHandle.AddrOfPinnedObject() + start * size, (UIntPtr)(count * size), FreeTensorHandleDelegate, GCHandle.ToIntPtr(dataHandle)); + else + return TF_NewTensor(dt, dims, dims.Length, dataHandle.AddrOfPinnedObject() + start * size, (UIntPtr)(count * size), FreeTensorHandleDelegate, GCHandle.ToIntPtr(dataHandle)); + } + + // General purpose constructor, specifies data type and gets pointer to buffer + // Is the default good, one where we let the user provide their own deallocator, or should we make a copy in that case? + /// + /// Low-level tensor constructor that creates a tensor from a buffer pointed to by an IntPtr. + /// + /// Specifies the data type held by the tensor, as well as how to interpret the provided data. + /// Describes the tensor shape, an array that indicates . + /// Pointer to the raw data that will be used to initialize the tensor. + /// The size of the data being passed in. + /// Deallocator method, it is invoked when the tensor is destroyed to release the data pointed to by . On platforms like iOS (or other static compilation platforms), yiou must annotate the method specified in the deallocator with a . + /// An optional argument of data that is passed to the deallocator method when the tensor is destroyed, you can use this to pass context information. + public TFTensor(TFDataType dataType, long[] dims, IntPtr data, size_t dataSize, Deallocator deallocator, IntPtr deallocatorData) : base(IntPtr.Zero) + { + if (dims == null) + throw new ArgumentNullException("dims"); + + handle = TF_NewTensor(dataType, dims, dims.Length, data, dataSize, deallocator, deallocatorData); + + } + + internal override void NativeDispose(IntPtr handle) + { + TF_DeleteTensor(handle); + } + + // extern TF_Tensor * TF_AllocateTensor (TF_DataType, const int64_t *dims, int num_dims, size_t len); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe TF_Tensor TF_AllocateTensor(TFDataType dataType, long[] dims, int num_dims, size_t len); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe TF_Tensor TF_AllocateTensor(TFDataType dataType, IntPtr zeroDim, int num_dims, size_t len); + + /// + /// Low-level: Creates an empty tensor of the specified type and shape, with the specified number of elements + /// + /// Data type. + /// Tensor shape. + /// Size in bytes of the tensor, this will be the actual memory allocated. + /// + /// It is the responsibility of the caller to ensure that the size is correct given the data type size + /// and the tensor dimension specified in dims. + /// + public TFTensor(TFDataType dataType, long[] dims, int size) : base(IntPtr.Zero) + { + if (dims == null) + throw new ArgumentNullException("dims"); + handle = TF_AllocateTensor(dataType, dims, dims.Length, (size_t)size); + } + + // extern void TF_DeleteTensor (TF_Tensor *); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe void TF_DeleteTensor(TF_Tensor tensor); + + // extern TF_DataType TF_TensorType (const TF_Tensor *); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe TFDataType TF_TensorType(TF_Tensor tensor); + + /// + /// Returns the data type for the tensor. + /// + /// The type of the tensor. + public TFDataType TensorType => TF_TensorType(handle); + + // extern int TF_NumDims (const TF_Tensor *); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe int TF_NumDims(TF_Tensor tensor); + + /// + /// Returns the number of dimensions in the tensor. + /// + /// + /// For single-dimension tensors the return is 1, 2 dimensions is 2 and so on. + /// + public int NumDims => TF_NumDims(handle); + + // extern int64_t TF_Dim (const TF_Tensor *tensor, int dim_index); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe long TF_Dim(TF_Tensor tensor, int dim_index); + + /// + /// Returns the number of elements on a specific dimension in the tensor. + /// + /// The tensor dimension. + /// Dimension that you are querying. + /// + /// If you have a tensor of 3 elements by 5, represented by [3 5], + /// the GetTensorDimension(0) will return 3, the GetTensorDimension(1) + /// will return 5. + /// + public long GetTensorDimension(int dimIndex) + { + return TF_Dim(handle, dimIndex); + } + + // extern size_t TF_TensorByteSize (const TF_Tensor *); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe size_t TF_TensorByteSize(TF_Tensor tensor); + + public size_t TensorByteSize => TF_TensorByteSize(handle); + + // extern void * TF_TensorData (const TF_Tensor *); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe IntPtr TF_TensorData(TF_Tensor tensor); + + /// + /// Returns a pointer to the raw data in the tensor. + /// + /// + /// The contents of the Data must be interpreted according to the type of the + /// data as described by the DataType property. The amount of data + /// is given by the the TensorByteSize property. + /// + public IntPtr Data => TF_TensorData(handle); + + /// + /// Returns the tensor shape, this is an array whose size determines the number of dimensions on the tensor, and each element is the size of the dimension + /// + /// + /// An array of size 0 is used for constants, an array of size 1 is used + /// for single-dimension arrays, where the dimension is the value of the + /// first element. And so on. + /// + public long[] Shape + { + get + { + var dims = new long[TF_NumDims(handle)]; + for (int i = 0; i < dims.Length; i++) + dims[i] = (int)TF_Dim(handle, i); + + return dims; + } + } + + /// + /// Converts a to a system type. + /// + /// The to be converted. + /// The system type corresponding to the given . + public static Type TypeFromTensorType(TFDataType type) + { + switch (type) + { + case TFDataType.Float: + return typeof(float); + case TFDataType.Double: + return typeof(double); + case TFDataType.Int32: + return typeof(int); + case TFDataType.UInt8: + return typeof(byte); + case TFDataType.Int16: + return typeof(short); + case TFDataType.Int8: + return typeof(sbyte); + case TFDataType.String: + throw new NotSupportedException(); + case TFDataType.Int64: + return typeof(long); + case TFDataType.Bool: + return typeof(bool); + case TFDataType.UInt16: + return typeof(ushort); + case TFDataType.Complex128: + return typeof(Complex); + default: + return null; + } + } + + /// + /// Converts a system type to a . + /// + /// The system type to be converted. + /// The corresponding to the given type. + public static TFDataType TensorTypeFromType(Type type) + { + if (type == typeof(float)) + return TFDataType.Float; + if (type == typeof(double)) + return TFDataType.Double; + if (type == typeof(int)) + return TFDataType.Int32; + if (type == typeof(byte)) + return TFDataType.UInt8; + if (type == typeof(short)) + return TFDataType.Int16; + if (type == typeof(sbyte)) + return TFDataType.Int8; + if (type == typeof(string)) + return TFDataType.String; + if (type == typeof(long)) + return TFDataType.Int64; + if (type == typeof(bool)) + return TFDataType.Bool; + if (type == typeof(ushort)) + return TFDataType.UInt16; + if (type == typeof(Complex)) + return TFDataType.Complex128; + + throw new ArgumentOutOfRangeException(nameof(type), $"The given type could not be mapped to an existing {nameof(TFDataType)}."); + } + + private static unsafe object FetchSimple(TFDataType dt, IntPtr data) + { + switch (dt) + { + case TFDataType.Float: + return *(float*)data; + case TFDataType.Double: + return *(double*)data; + case TFDataType.Int32: + return *(int*)data; + case TFDataType.UInt8: + return *(byte*)data; + case TFDataType.Int16: + return *(short*)data; + case TFDataType.Int8: + return *(sbyte*)data; + case TFDataType.String: + throw new NotImplementedException(); + case TFDataType.Int64: + return *(long*)data; + case TFDataType.Bool: + return *(bool*)data; + case TFDataType.UInt16: + return *(ushort*)data; + case TFDataType.Complex128: + return *(Complex*)data; + default: + return null; + } + } + + internal static unsafe void Copy(IntPtr src, void* target, int size) + { + Buffer.MemoryCopy((void*)src, target, size, size); + } + + internal static unsafe void FetchFlatArray(Array target, TFDataType dt, IntPtr data) + { + int len = target.Length; + switch (dt) + { + case TFDataType.Int8: + var asbyte = (sbyte[])target; + fixed (sbyte* p = &asbyte[0]) + Copy(data, p, len); + return; + case TFDataType.Bool: + var abool = (bool[])target; + fixed (bool* p = &abool[0]) + Copy(data, p, len); + return; + case TFDataType.UInt16: + var aushort = (ushort[])target; + fixed (ushort* p = &aushort[0]) + Copy(data, p, len * 2); + return; + case TFDataType.Complex128: + var acomplex = (Complex[])target; + fixed (Complex* p = &acomplex[0]) + Copy(data, p, len * sizeof(Complex)); + return; + case TFDataType.Float: + var afloat = (float[])target; + fixed (float* p = &afloat[0]) + Copy(data, p, len * sizeof(float)); + return; + case TFDataType.Double: + var adouble = (double[])target; + fixed (double* p = &adouble[0]) + Copy(data, p, len * sizeof(double)); + return; + case TFDataType.Int32: + var aint = (int[])target; + fixed (int* p = &aint[0]) + Copy(data, p, len * sizeof(double)); + return; + case TFDataType.UInt8: + var abyte = (byte[])target; + fixed (byte* p = &abyte[0]) + Copy(data, p, len * sizeof(byte)); + return; + case TFDataType.Int16: + var ashort = (short[])target; + fixed (short* p = &ashort[0]) + Copy(data, p, len * sizeof(short)); + return; + case TFDataType.Int64: + var along = (long[])target; + fixed (long* p = &along[0]) + Copy(data, p, len * sizeof(long)); + return; + case TFDataType.String: + // need to return an array of TFStrings [] + throw new NotImplementedException(); + default: + throw new NotImplementedException(); + } + } + + private static unsafe object FetchJaggedArray(Type t, TFDataType dt, ref IntPtr data, long[] shape, int level = 0) + { + Array target; + + // If we are at the last node + if (level == shape.Length - 1) + { + target = Array.CreateInstance(t, shape[level]); + + for (long l = 0; l < shape[level]; l++) + switch (dt) + { + case TFDataType.Float: + target.SetValue((*(float*)data), l); + data += 4; + break; + case TFDataType.Double: + target.SetValue((*(double*)data), l); + data += 8; + break; + case TFDataType.Int32: + target.SetValue((*(int*)data), l); + data += 4; + break; + case TFDataType.UInt8: + target.SetValue((*(byte*)data), l); + data += 1; + break; + case TFDataType.Int16: + target.SetValue((*(short*)data), l); + data += 2; + break; + case TFDataType.Int8: + target.SetValue((*(sbyte*)data), l); + data += 1; + break; + case TFDataType.Int64: + target.SetValue((*(long*)data), l); + data += 8; + break; + case TFDataType.Bool: + target.SetValue((*(bool*)data), l); + data += 1; + break; + case TFDataType.Complex128: + target.SetValue((*(Complex*)data), l); + data += sizeof(Complex); + break; + case TFDataType.String: + throw new NotImplementedException("String decoding not implemented for tensor vecotrs yet"); + default: + throw new NotImplementedException(); + } + } + else + { + target = null; + + long top = shape[level]; + if (top < Int32.MaxValue) + { + int itop = (int)top; + + for (int i = 0; i < itop; i++) + { + var childArray = FetchJaggedArray(t, dt, ref data, shape, level + 1); + if (target == null) + target = Array.CreateInstance(childArray.GetType(), shape[level]); + + target.SetValue(childArray, i); + } + } + else + { + for (long l = 0; l < top; l++) + { + + var chidArray = FetchJaggedArray(t, dt, ref data, shape, level + 1); + if (target == null) + target = Array.CreateInstance(chidArray.GetType(), shape[level]); + + target.SetValue(chidArray, l); + } + } + return target; + } + + return target; + } + + private static void FetchMultiDimensionalArray(Array target, TFDataType dt, IntPtr data, long[] shape) + { + var idx = new int[shape.Length]; + for (int i = 0; i < shape.Length; i++) + { + if (shape[i] > Int32.MaxValue) + throw new ArgumentOutOfRangeException("Shape can not be longer than 32 bits"); + } + Copy(target, dt, shape, idx, 0, ref data); + } + + private static unsafe void Copy(Array target, TFDataType dt, long[] shape, int[] idx, int level, ref IntPtr data) + { + if (level < shape.Length - 1) + { + for (idx[level] = 0; idx[level] < shape[level]; idx[level]++) + Copy(target, dt, shape, idx, level + 1, ref data); + } + else + { + for (idx[level] = 0; idx[level] < shape[level]; idx[level]++) + { + switch (dt) + { + case TFDataType.Float: + target.SetValue((*(float*)data), idx); + data += 4; + break; + case TFDataType.Double: + target.SetValue((*(double*)data), idx); + data += 8; + break; + case TFDataType.Int32: + target.SetValue((*(int*)data), idx); + data += 4; + break; + case TFDataType.UInt8: + target.SetValue((*(byte*)data), idx); + data += 1; + break; + case TFDataType.Int16: + target.SetValue((*(short*)data), idx); + data += 2; + break; + case TFDataType.Int8: + target.SetValue((*(sbyte*)data), idx); + data += 1; + break; + case TFDataType.Int64: + target.SetValue((*(long*)data), idx); + data += 8; + break; + case TFDataType.Bool: + target.SetValue((*(bool*)data), idx); + data += 1; + break; + case TFDataType.Complex128: + target.SetValue((*(Complex*)data), idx); + data += sizeof(Complex); + break; + case TFDataType.String: + throw new NotImplementedException("String decoding not implemented for tensor vecotrs yet"); + default: + throw new NotImplementedException(); + } + } + } + } + + /// + /// Returns the value of the Tensor as a C# type if possible, or null if the data type can not be represented in C# + /// + /// + /// The default is set to false, which returns .NET multi-dimensional arrays for multi-dimensional + /// tensors. This is useful to feed the data back as a TFTensor created from an array. Set to + /// true if you want to get arrays pointing to arrays, which are slightly more convenient to work + /// with from C# + /// + /// + /// Jagged arrays create various intermediate arrays, while multi-dimensional arrays are more + /// efficient memory-wise. + /// + /// The value encodes the contents of the tensor, and could include simple values, arrays and multi-dimensional values. + public object GetValue(bool jagged = false) + { + var dims = NumDims; + if (dims == 0) + return FetchSimple(TensorType, Data); + + var t = TypeFromTensorType(TensorType); + if (t == null) + return null; + + if (dims == 1) + { + var result = Array.CreateInstance(t, Shape[0]); + FetchFlatArray(result, TensorType, Data); + return result; + } + else + { + if (jagged) + { + IntPtr data = Data; + return FetchJaggedArray(t, TensorType, ref data, Shape); + } + else + { + var result = Array.CreateInstance(t, Shape); + FetchMultiDimensionalArray(result, TensorType, Data, Shape); + return result; + } + } + } + + /// + /// Returns a that represents the current . + /// + /// A that represents the current . + public override string ToString() + { + var n = NumDims; + if (n == 0) + return GetValue().ToString(); + + StringBuilder sb = new StringBuilder("["); + for (int i = 0; i < n; i++) + { + sb.Append(TF_Dim(handle, i)); + if (i + 1 < n) + sb.Append("x"); + } + sb.Append("]"); + return sb.ToString(); + } + + } + +} diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorGeneric.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorGeneric.cs new file mode 100644 index 0000000000..12e2bea004 --- /dev/null +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorGeneric.cs @@ -0,0 +1,139 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; + +namespace Microsoft.ML.Transforms.TensorFlow +{ + internal partial class TFTensor + { + /// + /// Creates a tensor representing type T. + /// The tensor will be backed with a managed-heap-allocated T. + /// + /// .NET type of tensor to create + /// value of tensor + public static TFTensor CreateScalar(T data) + { + if (typeof(T) == typeof(System.Boolean)) + { + return new TFTensor((System.Boolean)(object)data); + } + else if (typeof(T) == typeof(System.Byte)) + { + return new TFTensor((System.Byte)(object)data); + } + else if (typeof(T) == typeof(System.Char)) + { + return new TFTensor((System.Char)(object)data); + } + else if (typeof(T) == typeof(System.Numerics.Complex)) + { + return new TFTensor((System.Numerics.Complex)(object)data); + } + else if (typeof(T) == typeof(System.Double)) + { + return new TFTensor((System.Double)(object)data); + } + else if (typeof(T) == typeof(System.Single)) + { + return new TFTensor((System.Single)(object)data); + } + else if (typeof(T) == typeof(System.Int32)) + { + return new TFTensor((System.Int32)(object)data); + } + else if (typeof(T) == typeof(System.Int64)) + { + return new TFTensor((System.Int64)(object)data); + } + else if (typeof(T) == typeof(System.SByte)) + { + return new TFTensor((System.SByte)(object)data); + } + else if (typeof(T) == typeof(System.Int16)) + { + return new TFTensor((System.Int16)(object)data); + } + else if (typeof(T) == typeof(System.UInt32)) + { + return new TFTensor((System.UInt32)(object)data); + } + else if (typeof(T) == typeof(System.UInt64)) + { + return new TFTensor((System.UInt64)(object)data); + } + else if (typeof(T) == typeof(System.UInt16)) + { + return new TFTensor((System.UInt16)(object)data); + } + throw new NotSupportedException($"Unsupported type {typeof(T)}"); + } + + /// + /// Creates a tensor representing type T[]. + /// T[] will be pinned and wrapped in a tensor. + /// + /// .NET type of tensor to create + /// value of tensor + /// shape of tensor + public static TFTensor Create(T[] data, TFShape shape) + { + if (typeof(T) == typeof(System.Boolean)) + { + return new TFTensor(SetupTensor(TFDataType.Bool, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 4)); + } + else if (typeof(T) == typeof(System.Byte)) + { + return new TFTensor(SetupTensor(TFDataType.UInt8, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 1)); + } + else if (typeof(T) == typeof(System.Char)) + { + return new TFTensor(SetupTensor(TFDataType.UInt8, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 1)); + } + else if (typeof(T) == typeof(System.Numerics.Complex)) + { + return new TFTensor(SetupTensor(TFDataType.Complex128, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 16)); + } + else if (typeof(T) == typeof(System.Double)) + { + return new TFTensor(SetupTensor(TFDataType.Double, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 8)); + } + else if (typeof(T) == typeof(System.Single)) + { + return new TFTensor(SetupTensor(TFDataType.Float, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 4)); + } + else if (typeof(T) == typeof(System.Int32)) + { + return new TFTensor(SetupTensor(TFDataType.Int32, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 4)); + } + else if (typeof(T) == typeof(System.Int64)) + { + return new TFTensor(SetupTensor(TFDataType.Int64, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 8)); + } + else if (typeof(T) == typeof(System.SByte)) + { + return new TFTensor(SetupTensor(TFDataType.Int8, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 1)); + } + else if (typeof(T) == typeof(System.Int16)) + { + return new TFTensor(SetupTensor(TFDataType.Int16, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 2)); + } + else if (typeof(T) == typeof(System.UInt32)) + { + return new TFTensor(SetupTensor(TFDataType.UInt32, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 4)); + } + else if (typeof(T) == typeof(System.UInt64)) + { + return new TFTensor(SetupTensor(TFDataType.UInt64, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 8)); + } + else if (typeof(T) == typeof(System.UInt16)) + { + return new TFTensor(SetupTensor(TFDataType.UInt16, shape, (Array)(object)data, 0, ((Array)(object)data).Length, 2)); + } + // note that we will get here for jagged arrays, which is intententional since we'd need to copy them. + throw new NotSupportedException($"Unsupported type {typeof(T)}"); + } + } +} + diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorGeneric.tt b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorGeneric.tt new file mode 100644 index 0000000000..b2f44697c3 --- /dev/null +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorGeneric.tt @@ -0,0 +1,99 @@ +<#@ template debug="false" hostspecific="false" language="C#" #> +<#@ assembly name="System.Core" #> +<#@ assembly name="System.Numerics" #> +<#@ import namespace="System.Linq" #> +<#@ import namespace="System.Text" #> +<#@ import namespace="System.Collections.Generic" #> +<#@ import namespace="System.Runtime.InteropServices" #> +<#@ output extension=".cs" #>// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System; + +namespace Microsoft.ML.Transforms.TensorFlow +{ + internal partial class TFTensor + { + /// + /// Creates a tensor representing type T. + /// The tensor will be backed with a managed-heap-allocated T. + /// + /// .NET type of tensor to create + /// value of tensor + public static TFTensor CreateScalar(T data) + { +<# foreach (TypeConfiguration type in typeConfiguration) { #> + <#=GenerateIfStatementHeader(type)#> + { + return new TFTensor((<#=type.TypeName#>)(object)data); + } +<# } #> + throw new NotSupportedException($"Unsupported type {typeof(T)}"); + } + + /// + /// Creates a tensor representing type T[]. + /// T[] will be pinned and wrapped in a tensor. + /// + /// .NET type of tensor to create + /// value of tensor + /// shape of tensor + public static TFTensor Create(T[] data, TFShape shape) + { +<# foreach (TypeConfiguration type in typeConfiguration) { #> + <#=GenerateIfStatementHeader(type)#> + { + return new TFTensor(SetupTensor(TFDataType.<#=type.TFDataType#>, shape, (Array)(object)data, 0, ((Array)(object)data).Length, <#=type.Size#>)); + } +<# } #> + // note that we will get here for jagged arrays, which is intententional since we'd need to copy them. + throw new NotSupportedException($"Unsupported type {typeof(T)}"); + } + } +} + +<#+ + public class TypeConfiguration + { + public TypeConfiguration(Type type, string tfDataType) + { + Type = type; + TFDataType = tfDataType; + } + public string TypeName + { + get { return Type.ToString(); } + } + public Type Type { get; } + public string TFDataType { get; } + public int Size + { + get { return Marshal.SizeOf(Type); } + } + } + + public string GenerateIfStatementHeader(TypeConfiguration type, string lhs = "typeof(T)") + { + string keyword = (type == typeConfiguration[0]) ? "if" : "else if"; + return $"{keyword} ({lhs} == typeof({type.TypeName}))"; + } + + public TypeConfiguration[] typeConfiguration = new [] + { + new TypeConfiguration(typeof(bool), "Bool"), + new TypeConfiguration(typeof(byte), "UInt8"), + new TypeConfiguration(typeof(char), "UInt8"), + new TypeConfiguration(typeof(System.Numerics.Complex), "Complex128"), + // new TypeConfiguration(typeof(decimal), "unknown"), TF doesn't appear to have 128-bit floating-point + new TypeConfiguration(typeof(double),"Double"), + new TypeConfiguration(typeof(float), "Float"), + new TypeConfiguration(typeof(int), "Int32"), + new TypeConfiguration(typeof(long), "Int64"), + new TypeConfiguration(typeof(sbyte), "Int8"), + new TypeConfiguration(typeof(short), "Int16"), + new TypeConfiguration(typeof(uint), "UInt32"), + new TypeConfiguration(typeof(ulong), "UInt64"), + new TypeConfiguration(typeof(ushort), "UInt16") + // TODO, map other types + }; +#> \ No newline at end of file diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs new file mode 100644 index 0000000000..f52756d4a5 --- /dev/null +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs @@ -0,0 +1,2184 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Runtime.InteropServices; +using System.Text; +using System.Globalization; +using System.Linq; + +// We use this TF_Xxx as the native "TF_Xxx *" as those are opaque +using TF_Status = System.IntPtr; +using TF_SessionOptions = System.IntPtr; +using TF_Graph = System.IntPtr; +using TF_OperationDescription = System.IntPtr; +using TF_Operation = System.IntPtr; +using TF_Session = System.IntPtr; +using TF_DeprecatedSession = System.IntPtr; +using TF_Tensor = System.IntPtr; +using TF_ImportGraphDefOptions = System.IntPtr; +using TF_Library = System.IntPtr; +using TF_BufferPtr = System.IntPtr; +using TF_Function = System.IntPtr; +using TF_DeviceList = System.IntPtr; + +using size_t = System.UIntPtr; +using System.Numerics; +using System.Collections.Generic; +using System.Linq.Expressions; + +#pragma warning disable MSML_GeneralName +#pragma warning disable MSML_PrivateFieldName +#pragma warning disable MSML_ParameterLocalVarName + +namespace Microsoft.ML.Transforms.TensorFlow +{ + internal static partial class NativeBinding + { + public const string TensorFlowLibrary = "tensorflow"; + public const string TensorFlowLibraryGPU = "libtensorflowgpu"; + + internal static string GetStr(this IntPtr x) => Marshal.PtrToStringAnsi(x); + } + + /// + /// Contains TensorFlow fundamental methods and utility functions. + /// + internal static class TFCore + { + internal static bool UseCPU = true; + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe IntPtr TF_Version(); + + static TFCore() + { + Init(); + } + + internal static void Init() + { + CheckSize(); + } + + /// + /// Returns the version of the TensorFlow runtime in use. + /// + /// The version. + public static string Version => TF_Version().GetStr(); + + // extern size_t TF_DataTypeSize (TF_DataType dt); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern IntPtr TF_DataTypeSize(TFDataType dt); + + /// + /// Gets the size in bytes of the specified TensorFlow data type. + /// + /// The data type size. + /// Dt. + public static long GetDataTypeSize(TFDataType dt) => (long)TF_DataTypeSize(dt); + + // extern TF_Buffer * TF_GetAllOpList (); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe IntPtr TF_GetAllOpList(); + + /// + /// Retrieves the ProtocolBuffer describing all of the available operations in + /// the TensorFlow library in current use. + /// + /// The buffer contains a ProtocolBuffer encoded payload, you need a ProtocolBuffer reader to process the contents. + public static TFBuffer GetAllOpList() + { + return new TFBuffer(TF_GetAllOpList()); + } + + private static void CheckSize() + { + unsafe + { + if (sizeof(IntPtr) == 4) + { + Console.Error.WriteLine( + "The TensorFlow native libraries were compiled in 64 bit mode, you must run in 64 bit mode\n" + + "With Mono, do that with mono --arch=64 executable.exe, if using an IDE like MonoDevelop,\n" + + "Xamarin Studio or Visual Studio for Mac, Build/Compiler settings, make sure that " + + "\"Platform Target\" has x64 selected."); + throw new Exception(); + + } + } + } + } + + /// + /// Base class for many TensorFlow data types that provides a common idiom to dispose and + /// release resources associated with the native data types. Generally, you do not need to use this. + /// + /// + /// + /// This implements the Dispose pattern in a reusable form for TensorFlow types. + /// + /// + /// Subclasses invoke the constructor with the handle that this will wrap, and must + /// override the NativeDispose method (internal) to release the associated resource. + /// + /// + internal abstract class TFDisposable : IDisposable + { + internal IntPtr handle; + + /// + /// Returns the opaque handle to the object that this TFDisposable owns. + /// + /// The handle. + public IntPtr Handle => handle; + + static TFDisposable() + { + TFCore.Init(); + } + + /// + /// Initializes a new instance of the class. + /// + public TFDisposable() + { } + + /// + /// Initializes a new instance of the class + /// from the handle that it will wrap. + /// + public TFDisposable(IntPtr handle) + { + this.handle = handle; + } + + /// + /// Releases all resource used by the object. + /// + /// Call Dispose when you are finished using the . The + /// Dispose method leaves the in an unusable state. After + /// calling Dispose, you must release all references to the so + /// the garbage collector can reclaim the memory that the was occupying. + public void Dispose() + { + Dispose(true); + GC.SuppressFinalize(this); + } + + ~TFDisposable() + { + Dispose(false); + } + + // Must be implemented in subclasses to dispose the unmanaged object, it does + // not need to take care of zeroing out the handle, that is done by the Dispose + // method inherited from TFDisposable + internal abstract void NativeDispose(IntPtr handle); + + /// + /// Dispose the specified object + /// + /// If set to true it means that this method was called from Dispose, otherwise from the finalizer. + public virtual void Dispose(bool disposing) + { + if (disposing) + { + if (handle != IntPtr.Zero) + NativeDispose(handle); + handle = IntPtr.Zero; + } + } + + internal static void ObjectDisposedException() + { + throw new ObjectDisposedException("The object was disposed"); + } + } + + /// + /// ase class for many TensorFlow data types that provides a common idiom to dispose and + /// release resources associated with the native data types and whose unmanaged resource + /// disposing can be called from a background thread (the finalizer). Users do not + /// need to deal with this class. + /// + /// + /// Some object deletion APIs in TensorFlow can be invoked from a background thread, + /// so the release methods are suitable to be invoked from the Finalizer thread, in + /// those scenarios, subclass from this class rather than the TFDisposable class. + /// + internal abstract class TFDisposableThreadSafe : TFDisposable + { + /// + /// Initializes a new instance of the class + /// from the handle that it will wrap. + /// + public TFDisposableThreadSafe(IntPtr handle) : base(handle) + { + } + + /// + /// Initializes a new instance of the class. + /// + public TFDisposableThreadSafe() + { } + + /// + /// Dispose the object, unlike the default implementat in TFDisposable, + /// this will release the unmanaged resources from a background thread. + /// + /// If set to true disposing. + public override void Dispose(bool disposing) + { + if (handle != IntPtr.Zero) + NativeDispose(handle); + handle = IntPtr.Zero; + } + } + + /// + /// TensorFlow Exception + /// + internal class TFException : Exception + { + /// + /// Initializes a new instance of the class with a message. + /// + /// Message. + public TFException(string message) : base(message) { } + } + + /// + /// Used to track the result of TensorFlow operations. + /// + /// + /// + /// TFStatus is used to track the status of a call to some TensorFlow + /// operations. Instances of this object are passed to various + /// TensorFlow operations and you can use the + /// to quickly check if the operation succeeded, or get more detail from the + /// and a human-readable text + /// using the property. + /// + /// + /// The convenience can be used + /// to raise a if the status of the + /// operation did not succeed. + /// + /// + internal class TFStatus : TFDisposable + { + // extern TF_Status * TF_NewStatus (); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe TF_Status TF_NewStatus(); + + /// + /// Per-thread global status that you can use if you do not need to create a new instance of this object. + /// + /// + /// This is provided as a convenience for APIs that take a TFStatus. While the TFStatus is usually an + /// optional parameter, when it is made optional, API calls that fail raise an exception. Use this + /// property to pass a TFStatus without having to allocate a new one. The problem with this of course + /// is that you risk having multiple parts of your code override this thread-global variable. + /// + [ThreadStatic] public static TFStatus Default = new TFStatus(); + + /// + /// Initializes a new instance of the class. + /// + public TFStatus() : base(TF_NewStatus()) + { + } + + // extern void TF_DeleteStatus (TF_Status *); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe void TF_DeleteStatus(TF_Status status); + + internal override void NativeDispose(IntPtr handle) + { + TF_DeleteStatus(handle); + } + + // extern void TF_SetStatus (TF_Status *s, TF_Code code, const char *msg); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe void TF_SetStatus(TF_Status s, TFCode code, string msg); + + /// + /// Sets the status code on this TFStatus. + /// + /// Code. + /// Message. + public void SetStatusCode(TFCode code, string msg) + { + TF_SetStatus(handle, code, msg); + } + + // extern TF_Code TF_GetCode (const TF_Status *s); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe TFCode TF_GetCode(TF_Status s); + + /// + /// Gets the status code for the status code. + /// + /// The status code as an enumeration. + public TFCode StatusCode + { + get + { + if (handle == IntPtr.Zero) + throw new ObjectDisposedException("TFStatus"); + return TF_GetCode(handle); + } + } + + // extern const char * TF_Message (const TF_Status *s); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe IntPtr TF_Message(TF_Status s); + + /// + /// Gets a human-readable status message. + /// + /// The status message. + public string StatusMessage => TF_Message(handle).GetStr(); + + /// + /// Returns a that represents the current . + /// + /// A that represents the current . + public override string ToString() + { + if (handle == IntPtr.Zero) + throw new ObjectDisposedException("TFStatus"); + + return string.Format("[TFStatus: StatusCode={0}, StatusMessage={1}]", StatusCode, StatusMessage); + } + + /// + /// Gets a value indicating whether this state has been set to ok. + /// + /// true if ok; otherwise, false. + public bool Ok => StatusCode == TFCode.Ok; + + /// + /// Gets a value indicating whether this state has been set to an error. + /// + /// true if error; otherwise, false. + public bool Error => StatusCode != TFCode.Ok; + + /// + /// Convenience method that raises an exception if the current status is an error. + /// + /// + /// You can use this method as a convenience to raise an exception after you + /// invoke an operation if the operation did not succeed. + /// + public void Raise() + { + if (TF_GetCode(handle) != TFCode.Ok) + throw new TFException(StatusMessage); + } + + // + // Utility function used to simplify implementing the idiom + // where the user optionally provides a TFStatus, if it is provided, + // the error is returned there; If it is not provided, then an + // exception is raised. + // + + internal bool CheckMaybeRaise(TFStatus incomingStatus, bool last = true) + { + if (incomingStatus == null) + { + if (handle == IntPtr.Zero) + Console.WriteLine("oops"); + if (StatusCode != TFCode.Ok) + { + var e = new TFException(StatusMessage); + if (last) + Dispose(); + throw e; + } + if (last) + Dispose(); + return true; + } + return StatusCode == TFCode.Ok; + } + + internal static TFStatus Setup(TFStatus incoming) + { + return incoming == null ? new TFStatus() : incoming; + } + } + + /// + /// The session options object holds configuration options that you want to use during your session, like the TensorFlow target or the configuration. + /// + internal class TFSessionOptions : TFDisposable + { + // extern TF_SessionOptions * TF_NewSessionOptions (); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe TF_SessionOptions TF_NewSessionOptions(); + + /// + /// Initializes a new instance of the class. + /// + public TFSessionOptions() : base(TF_NewSessionOptions()) { } + + // extern void TF_DeleteSessionOptions (TF_SessionOptions *); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe void TF_DeleteSessionOptions(TF_SessionOptions options); + internal override void NativeDispose(IntPtr handle) + { + TF_DeleteSessionOptions(handle); + } + + // extern void TF_SetTarget (TF_SessionOptions *options, const char *target); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe void TF_SetTarget(TF_SessionOptions options, string target); + + /// + /// Sets the target in options. + /// + /// target can be empty, a single entry, or a comma separated list of entries. + /// Each entry is in one of the following formats: "local", ip:port, host:port. + /// + public void SetTarget(string target) + { + if (handle == IntPtr.Zero) + ObjectDisposedException(); + + TF_SetTarget(handle, target); + } + + // extern void TF_SetConfig (TF_SessionOptions *options, const void *proto, size_t proto_len, TF_Status *status); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe void TF_SetConfig(TF_SessionOptions options, IntPtr proto, size_t proto_len, TF_Status status); + + /// + /// Sets the configuration information for the session. + /// + /// Serialized protocol buffer for the tensorflow.ConfigProto message. + /// Length of the buffer. + /// If config was not parsed successfully as a ConfigProto, the error is recorded here. + /// + /// The configuration option is a Protocol Buffer representing the tensorflow.ConfigProto + /// + public void SetConfig(IntPtr protoData, int length, TFStatus status = null) + { + if (handle == IntPtr.Zero) + ObjectDisposedException(); + + var cstatus = TFStatus.Setup(status); + + TF_SetConfig(handle, protoData, (UIntPtr)length, cstatus.handle); + cstatus.CheckMaybeRaise(status); + } + + } + + /// + /// Represents a computation graph. Graphs may be shared between sessions and are thread safe. + /// + /// + /// + /// Graphs consist of operations (represented by TFOperation objects), these can be named, or + /// the runtime will automatically assign a name. + /// + /// + /// For debugging purposes, you might want to group operations together, for this, call the + /// WithScope method with your new scope, which will create a new namespace for your object names. + /// + /// + /// For example, if you call WithScope ("demo"), and add an operation named "add" inside the + /// scope, the full name of the operation will be "demo/add", if you create a new scope inside, say + /// "hot", and add a "sub" operation there the result will be "demo/hot/sub". + /// + /// + internal partial class TFGraph : TFDisposableThreadSafe + { + // extern TF_Graph * TF_NewGraph (); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe TF_Graph TF_NewGraph(); + + /// + /// Initializes a new instance of the class. + /// + public TFGraph() : base(TF_NewGraph()) + { + } + + // extern void TF_DeleteGraph (TF_Graph *); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe void TF_DeleteGraph(TF_Graph graph); + internal override void NativeDispose(IntPtr handle) + { + TF_DeleteGraph(handle); + } + + // extern int TF_GraphGetTensorNumDims (TF_Graph *graph, TF_Output output, TF_Status *status); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe int TF_GraphGetTensorNumDims(TF_Graph graph, TFOutput output, TF_Status status); + + // extern void TF_GraphGetTensorShape (TF_Graph *graph, TF_Output output, int64_t *dims, int num_dims, TF_Status *status); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe void TF_GraphGetTensorShape(TF_Graph graph, TFOutput output, long[] dims, int num_dims, TF_Status status); + + /// + /// Returns the shape of a tensor specified in . + /// + /// + /// The tensor shape. If the number of dimensions in the shape is unknown or the shape is, a scalar, the values in the array will be zero. Otherwise, each element of will be set corresponding to the size of the dimension. An unknown dimension is represented by -1. + /// The tensor that you want to look up. + /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. + public TFShape GetTensorShape(TFOutput output, TFStatus status = null) + { + if (handle == IntPtr.Zero) + ObjectDisposedException(); + var cstatus = TFStatus.Setup(status); + var n = TF_GraphGetTensorNumDims(handle, output, cstatus.handle); + if (!cstatus.CheckMaybeRaise(status, last: false)) + return TFShape.Unknown; + if (n == -1) + return TFShape.Unknown; + + var dims = new long[n]; + TF_GraphGetTensorShape(handle, output, dims, dims.Length, cstatus.handle); + cstatus.CheckMaybeRaise(status); + return new TFShape(dims); + } + + // extern void TF_GraphToGraphDef (TF_Graph *graph, TF_Buffer *output_graph_def, TF_Status *status); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe void TF_GraphToGraphDef(TF_Graph graph, LLBuffer* output_graph_def, TF_Status status); + + /// + /// Write out a serialized representation of the graph (as a GraphDef protocol buffer message) into . + /// + /// Target buffer where the graphs is serialized into. + /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. + public void ToGraphDef(TFBuffer outputGraphDef, TFStatus status = null) + { + if (handle == IntPtr.Zero) + ObjectDisposedException(); + if (outputGraphDef == null) + throw new ArgumentNullException(nameof(outputGraphDef)); + + var cstatus = TFStatus.Setup(status); + unsafe + { + TF_GraphToGraphDef(handle, outputGraphDef.LLBuffer, cstatus.handle); + } + cstatus.CheckMaybeRaise(status); + } + + // extern void TF_GraphImportGraphDef (TF_Graph *graph, const TF_Buffer *graph_def, const TF_ImportGraphDefOptions *options, TF_Status *status); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe void TF_GraphImportGraphDef(TF_Graph graph, LLBuffer* graph_def, TF_ImportGraphDefOptions options, TF_Status status); + + /// + /// Import a serialized graph into this graph, using the specified prefix. + /// + /// The import. + /// A buffer containing the serialized graph. + /// A prefix that will be prepended to names of nodes in the when they are imported into the graph. + /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. + public void Import(TFBuffer graphDef, string prefix = "", TFStatus status = null) + { + if (handle == IntPtr.Zero) + ObjectDisposedException(); + if (graphDef == null) + throw new ArgumentNullException(nameof(graphDef)); + if (prefix == null) + throw new ArgumentNullException(nameof(prefix)); + + using (var options = new TFImportGraphDefOptions()) + { + options.SetPrefix(prefix); + Import(graphDef, options, status); + } + } + + /// + /// Import a serialized graph into this graph, using the specified importing options. + /// + /// The import. + /// A buffer containing the serialized graph. + /// Importing graph options. + /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. + public void Import(TFBuffer graphDef, TFImportGraphDefOptions options, TFStatus status = null) + { + if (handle == IntPtr.Zero) + ObjectDisposedException(); + if (graphDef == null) + throw new ArgumentNullException(nameof(graphDef)); + if (options == null) + throw new ArgumentNullException(nameof(options)); + + var cstatus = TFStatus.Setup(status); + unsafe + { + TF_GraphImportGraphDef(handle, graphDef.LLBuffer, options.handle, cstatus.handle); + } + cstatus.CheckMaybeRaise(status); + } + + /// + /// Import a serialized graph held in a byte array into this graph, using the specified prefix. + /// + /// The import. + /// A byte array containing the serialized graph. + /// A prefix that will be prepended to names of nodes in the graph when they are imported into the graph. + /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. + public void Import(byte[] buffer, string prefix = "", TFStatus status = null) + { + if (handle == IntPtr.Zero) + ObjectDisposedException(); + if (buffer == null) + throw new ArgumentNullException(nameof(buffer)); + if (prefix == null) + throw new ArgumentNullException(nameof(prefix)); + using (var options = new TFImportGraphDefOptions()) + { + options.SetPrefix(prefix); + Import(buffer, options, status); + } + } + + /// + /// Import a serialized graph held in a byte array into this graph, using the specified import options. + /// + /// The import. + /// A byte array containing the serialized graph. + /// Importing graph options. + /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. + /// + /// If you are tryig to load a file stored using the SavedModel file format, you should use the API instead. + /// + public void Import(byte[] buffer, TFImportGraphDefOptions options, TFStatus status = null) + { + if (handle == IntPtr.Zero) + ObjectDisposedException(); + if (buffer == null) + throw new ArgumentNullException(nameof(buffer)); + if (options == null) + throw new ArgumentNullException(nameof(options)); + var cstatus = TFStatus.Setup(status); + using (var tb = new TFBuffer(buffer, 0, buffer.Length)) + Import(tb, options, status); + + cstatus.CheckMaybeRaise(cstatus); + } + + // extern TF_Operation * TF_GraphOperationByName (TF_Graph *graph, const char *oper_name); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe TF_Operation TF_GraphOperationByName(TF_Graph graph, string oper_name); + + /// + /// Gets the with the specified name, or null if the named operation does not exist in the graph. + /// + /// Name to lookup. + public TFOperation this[string name] + { + get + { + if (handle == IntPtr.Zero) + ObjectDisposedException(); + var h = TF_GraphOperationByName(handle, name); + if (h == IntPtr.Zero) + return null; + return new TFOperation(this, h); + } + } + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern string TF_GraphDebugString(TF_Graph graph, out IntPtr len); + + public override string ToString() + { + IntPtr len; + return TF_GraphDebugString(Handle, out len); + } + } + + /// + /// Represents a computation node in the graph. Tensorflow operations are attached to a . + /// + /// + /// TFOperations are usually created by invoking one of the methods in + /// , but they can also be constructed + /// manually using the low-level API. + /// + internal partial class TFOperation + { + internal IntPtr handle; + + /// + /// Gets the handle to the unmanaged TF_Operation object. + /// + /// The handle. + public IntPtr Handle => handle; + + // Pointer to the graph, to keep it from collecting if there are TFOperations alive. + internal TFGraph graph; + + internal TFOperation(TFGraph graph, IntPtr handle) + { + this.handle = handle; + this.graph = graph; + } + + /// + /// Returns the handle to the idx-th output of the operation. + /// + /// Index of the output in the operation. + public TFOutput this[int idx] + { + get + { + return new TFOutput(this, idx); + } + } + } + + /// + /// Device type + /// + internal enum DeviceType + { + /// + /// The device is the Central Processing Unit (CPU) + /// + CPU, + + /// + /// The device is a Graphics Processing Unit (GPU) + /// + GPU, + + /// + /// The device is a Tensor Processing Unit (TPU) + /// + TPU + } + + /// + /// Describes the device attributes + /// + internal class DeviceAttributes + { + internal DeviceAttributes(string name, DeviceType deviceType, long memoryLimitBytes) + { + Name = name; + DeviceType = deviceType; + MemoryLimitBytes = memoryLimitBytes; + } + + /// + /// The full name of the device (e.g. /job:worker/replica:0/...) + /// + public string Name { get; private set; } + + /// + /// Gets the type of the device. + /// + /// The type of the device. + public DeviceType DeviceType { get; private set; } + + /// + /// The amount of memory associated with a given device. + /// + /// The memory limit bytes. + public long MemoryLimitBytes { get; private set; } + } + + /// + /// Contains options that are used to control how graph importing works. + /// + internal class TFImportGraphDefOptions : TFDisposable + { + // extern TF_ImportGraphDefOptions * TF_NewImportGraphDefOptions (); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe TF_ImportGraphDefOptions TF_NewImportGraphDefOptions(); + + public TFImportGraphDefOptions() : base(TF_NewImportGraphDefOptions()) + { + } + + // extern void TF_DeleteImportGraphDefOptions (TF_ImportGraphDefOptions *opts); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe void TF_DeleteImportGraphDefOptions(TF_ImportGraphDefOptions opts); + + internal override void NativeDispose(IntPtr handle) + { + TF_DeleteImportGraphDefOptions(handle); + } + + // extern void TF_ImportGraphDefOptionsSetPrefix (TF_ImportGraphDefOptions *opts, const char *prefix); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions opts, string prefix); + + public void SetPrefix(string prefix) + { + if (handle == IntPtr.Zero) + ObjectDisposedException(); + TF_ImportGraphDefOptionsSetPrefix(handle, prefix); + } + + // extern void TF_ImportGraphDefOptionsAddInputMapping (TF_ImportGraphDefOptions *opts, const char* src_name, int src_index, TF_Output dst); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions opts, string src_name, int src_index, TFOutput dst); + + /// + /// Adds an input mapping from a source name and index to a destination output + /// + /// Source name. + /// Source index (in the source). + /// Replacement value for the srcName:srcIndex. + /// + /// Set any imported nodes with input `src_name:src_index` to have that input + /// replaced with `dst`. `src_name` refers to a node in the graph to be imported, + /// `dst` references a node already existing in the graph being imported into. + /// + public void AddInputMapping(string srcName, int srcIndex, TFOutput dst) + { + if (handle == IntPtr.Zero) + ObjectDisposedException(); + TF_ImportGraphDefOptionsAddInputMapping(handle, srcName, srcIndex, dst); + } + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern void TF_ImportGraphDefOptionsAddControlDependency(TF_ImportGraphDefOptions opts, TF_Operation oper); + + /// + /// Cause the imported graph to have a control dependency on the provided operation. + /// + /// This operation should exist in the graph being imported to. + public void AddControlDependency(TFOperation operation) + { + if (operation == null) + throw new ArgumentNullException(nameof(operation)); + if (handle == IntPtr.Zero) + ObjectDisposedException(); + + TF_ImportGraphDefOptionsAddControlDependency(handle, operation.handle); + } + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions opts, string oper_name, int index); + + /// + /// Add an output in the graph definition to be returned via the return outputs parameter. + /// + /// Operation name. + /// Operation index. + /// + /// If the output is remapped via an input + /// mapping, the corresponding existing tensor in graph will be returned. + /// + public void AddReturnOutput(string operName, int index) + { + if (operName == null) + throw new ArgumentNullException(nameof(operName)); + if (handle == IntPtr.Zero) + ObjectDisposedException(); + TF_ImportGraphDefOptionsAddReturnOutput(handle, operName, index); + } + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern int TF_ImportGraphDefOptionsNumReturnOutputs(TF_ImportGraphDefOptions opts); + + /// + /// Gets the number return outputs added via AddReturnOutput. + /// + /// The number return outputs. + public int NumReturnOutputs + { + get + { + if (handle == IntPtr.Zero) + ObjectDisposedException(); + return TF_ImportGraphDefOptionsNumReturnOutputs(handle); + } + } + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern void TF_ImportGraphDefOptionsRemapControlDependency(TF_ImportGraphDefOptions opts, string srcName, TF_Operation dst); + + /// + /// Sets any imported nodes with a given control input to have it replaced with an operation + /// + /// Node in the graph to be imported. + /// References an operation that already exists in the graph being imported. + /// + /// Set any imported nodes with control input to have that input + /// replaced with . + /// + public void RemapControlDependency(string srcName, TFOperation destination) + { + if (handle == IntPtr.Zero) + ObjectDisposedException(); + if (srcName == null) + throw new ArgumentNullException(nameof(srcName)); + if (destination == null) + throw new ArgumentNullException(nameof(destination)); + if (destination.Handle == IntPtr.Zero) + throw new ObjectDisposedException(nameof(destination)); + TF_ImportGraphDefOptionsRemapControlDependency(handle, srcName, destination.Handle); + } + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions opts, byte uniquify); + + /// + /// Set whether to uniquify imported operation names. + /// + /// If set to true imported operation names will be modified if their name already exists in the graph. + /// If set to false conflicting names will be treated as an error. + /// + /// + /// Note that this option has no effect if a prefix is set, since the prefix will guarantee all names are + /// Defaults to false. + /// + public void SetUniquifyNames(bool uniquifyNames) + { + if (handle == IntPtr.Zero) + ObjectDisposedException(); + + TF_ImportGraphDefOptionsSetUniquifyNames(handle, uniquifyNames ? (byte)1 : (byte)0); + } + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions opts, byte uniquify_prefix); + + /// + /// Sets the uniquify prefix. This option has no effect if no prefix is specified. + /// + /// If set to true the specified prefix will be modified if it already exists as an + /// operation name or prefix in the graph. + /// If set to false a conflicting prefix will be treated as an error. + /// + public void SetUniquifyPrefix(bool uniquifyPrefix) + { + if (handle == IntPtr.Zero) + ObjectDisposedException(); + TF_ImportGraphDefOptionsSetUniquifyPrefix(handle, uniquifyPrefix ? (byte)1 : (byte)0); + } + } + + /// + /// Drives the execution of a graph + /// + /// + /// + /// This creates a new context to execute a TFGraph. You can use the + /// constructor to create an empty session, or you can load an existing + /// model using the static method in this class. + /// + /// + /// To execute operations with the graph, call the method + /// which returns an object that you can use to build the operation by providing + /// the inputs, requesting the operations that you want to execute and the desired outputs. + /// + /// + /// The method is a high-level helper function that wraps a + /// call to the method which just takes too many parameters that must + /// be kept in sync. + /// + /// + internal class TFSession : TFDisposableThreadSafe + { + // extern TF_Session * TF_NewSession (TF_Graph *graph, const TF_SessionOptions *opts, TF_Status *status); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe TF_Session TF_NewSession(TF_Graph graph, TF_SessionOptions opts, TF_Status status); + + /// + /// Gets the graph associated with this TensorFlow session. + /// + /// The graph. + public TFGraph Graph { get; private set; } + + private TFSession(IntPtr handle, TFGraph graph) : base(handle) + { + Graph = graph; + } + + /// + /// Creates a new execution session associated with the specified session graph with some configuration options. + /// + /// The Graph to which this session is associated. + /// Session options. + /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. + public TFSession(TFGraph graph, TFSessionOptions sessionOptions, TFStatus status = null) : base(IntPtr.Zero) + { + Graph = graph; + var cstatus = TFStatus.Setup(status); + var h = TF_NewSession(graph.handle, sessionOptions.handle, cstatus.handle); + cstatus.CheckMaybeRaise(status); + handle = h; + } + + /// + /// Creates a new execution session associated with the specified session graph. + /// + /// The Graph to which this session is associated. + /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. + public TFSession(TFGraph graph, TFStatus status = null) : base(IntPtr.Zero) + { + Graph = graph; + var cstatus = TFStatus.Setup(status); + TF_Status h; + using (var empty = new TFSessionOptions()) + { + h = TF_NewSession(graph.handle, empty.Handle, cstatus.handle); + } + cstatus.CheckMaybeRaise(status); + handle = h; + } + + /// + /// Creates a new execution session with an empty graph + /// + /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. + /// + /// The created graph can be retrieved using the Graph property on the session. + /// + public TFSession(TFStatus status = null) : this(new TFGraph(), status) + { + } + + // extern TF_Session * TF_LoadSessionFromSavedModel (const TF_SessionOptions *session_options, const TF_Buffer *run_options, const char *export_dir, const char *const *tags, int tags_len, TF_Graph *graph, TF_Buffer *meta_graph_def, TF_Status *status); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe TF_Session TF_LoadSessionFromSavedModel(TF_SessionOptions session_options, LLBuffer* run_options, string export_dir, string[] tags, int tags_len, TF_Graph graph, LLBuffer* meta_graph_def, TF_Status status); + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe TF_DeviceList TF_SessionListDevices(TF_Session session, TF_Status status); + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe int TF_DeviceListCount(TF_DeviceList list); + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe IntPtr TF_DeviceListName(TF_DeviceList list, int index, TF_Status status); + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe IntPtr TF_DeviceListType(TF_DeviceList list, int index, TF_Status status); + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe long TF_DeviceListMemoryBytes(TF_DeviceList list, int index, TF_Status status); + + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe void TF_DeleteDeviceList(TF_DeviceList list); + + /// + /// Lists available devices in this session. + /// + /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. + public IEnumerable ListDevices(TFStatus status = null) + { + var cstatus = TFStatus.Setup(status); + var rawDeviceList = TF_SessionListDevices(Handle, cstatus.handle); + var size = TF_DeviceListCount(rawDeviceList); + + var list = new List(); + for (var i = 0; i < size; i++) + { + var name = Marshal.PtrToStringAnsi(TF_DeviceListName(rawDeviceList, i, cstatus.handle)); + var deviceType = (DeviceType)Enum.Parse(typeof(DeviceType), Marshal.PtrToStringAnsi(TF_DeviceListType(rawDeviceList, i, cstatus.handle))); + var memory = TF_DeviceListMemoryBytes(rawDeviceList, i, cstatus.handle); + + list.Add(new DeviceAttributes(name, deviceType, memory)); + } + + TF_DeleteDeviceList(rawDeviceList); + + return list; + } + + /// + /// Creates a session and graph from a model stored in the SavedModel file format. + /// + /// On success, this populates the provided with the contents of the graph stored in the specified model and with the MetaGraphDef of the loaded model. + /// Session options to use for the new session. + /// Options to use to initialize the state (can be null). + /// must be set to the path of the exported SavedModel. + /// must include the set of tags used to identify one MetaGraphDef in the SavedModel. + /// This must be a newly created graph. + /// On success, this will be populated on return with the contents of the MetaGraphDef (can be null). + /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. + /// + /// + /// This function creates a new session using the specified and then initializes + /// the state (restoring tensors and other assets) using . + /// + /// + /// This function loads the data that was saved using the SavedModel file format, as described + /// here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md + /// + /// + public TFSession FromSavedModel(TFSessionOptions sessionOptions, TFBuffer runOptions, string exportDir, string[] tags, TFGraph graph, TFBuffer metaGraphDef, TFStatus status = null) + { + if (graph == null) + throw new ArgumentNullException(nameof(graph)); + if (tags == null) + throw new ArgumentNullException(nameof(tags)); + if (exportDir == null) + throw new ArgumentNullException(nameof(exportDir)); + if (metaGraphDef == null) + throw new ArgumentNullException(nameof(metaGraphDef)); + var cstatus = TFStatus.Setup(status); + unsafe + { + var h = TF_LoadSessionFromSavedModel(sessionOptions.handle, runOptions == null ? null : runOptions.LLBuffer, exportDir, tags, tags.Length, graph.handle, metaGraphDef == null ? null : metaGraphDef.LLBuffer, cstatus.handle); + + if (cstatus.CheckMaybeRaise(status)) + { + return new TFSession(h, graph); + } + } + return null; + } + + // extern void TF_CloseSession (TF_Session *, TF_Status *status); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe void TF_CloseSession(TF_Session session, TF_Status status); + + /// + /// Closes the session. Contacts any other processes associated with the session, if applicable. + /// + /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. + /// + /// Can not be called after calling DeleteSession. + /// + public void CloseSession(TFStatus status = null) + { + if (handle == IntPtr.Zero) + ObjectDisposedException(); + var cstatus = TFStatus.Setup(status); + TF_CloseSession(handle, cstatus.handle); + cstatus.CheckMaybeRaise(status); + } + + // extern void TF_DeleteSession (TF_Session *, TF_Status *status); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe void TF_DeleteSession(TF_Session session, TF_Status status); + + /// + /// Deletes the session. + /// + /// Status. + public void DeleteSession(TFStatus status = null) + { + if (handle == IntPtr.Zero) + ObjectDisposedException(); + var cstatus = TFStatus.Setup(status); + TF_DeleteSession(handle, cstatus.handle); + cstatus.CheckMaybeRaise(status); + } + + internal override void NativeDispose(IntPtr handle) + { + using (var s = new TFStatus()) + { + TF_DeleteSession(handle, s.handle); + } + } + + // extern void TF_SessionRun (TF_Session *session, const TF_Buffer *run_options, const TF_Output *inputs, TF_Tensor *const *input_values, int ninputs, const TF_Output *outputs, TF_Tensor **output_values, int noutputs, const TF_Operation *const *target_opers, int ntargets, TF_Buffer *run_metadata, TF_Status *); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe void TF_SessionRun(TF_Session session, LLBuffer* run_options, TFOutput[] inputs, TF_Tensor[] input_values, int ninputs, TFOutput[] outputs, TF_Tensor[] output_values, int noutputs, TF_Operation[] target_opers, int ntargets, LLBuffer* run_metadata, TF_Status status); + + /// + /// Use the runner class to easily configure inputs, outputs and targets to be passed to the session runner. + /// + /// + /// + /// The runner has a simple API that allows developers to call the AddTarget, AddInput, AddOutput and Fetch + /// to construct the parameters that will be passed to the TFSession.Run method. + /// + /// + /// Instances of this class are created by calling the GetRunner method on the TFSession. + /// + /// + /// The various methods in this class return an instance to the Runner itsel, to allow + /// to easily construct chains of execution like this: + /// + /// + /// var result = session.GetRunner ().AddINput (myInput).Fetch (MyOutput).Run (); + /// + /// + /// You do not need to chain the operations, this works just the same: + /// + /// + /// runner = session.GetRunner (); + /// runner.AddInput(myInput); + /// runner.Fetch(myOutput); + /// var results = runner.Run(); + /// + /// + public class Runner + { + private List inputs; + private List outputs; + private List inputValues; + private List targets; + private TFSession session; + + internal Runner(TFSession session) + { + inputs = new List(); + outputs = new List(); + inputValues = new List(); + targets = new List(); + this.session = session; + RunMetadata = null; + RunOptions = null; + } + + /// + /// Adds an input to the session + /// + /// An instance to the runner, so you can easily chain the operations together. + /// Incoming port. + /// Value to assing to the incoming port. + public Runner AddInput(TFOutput input, TFTensor value) + { + if (value == null) + throw new ArgumentNullException(nameof(value)); + inputs.Add(input); + inputValues.Add(value); + return this; + } + + /// + /// Adds an input to the session specified by name, with an optional index in the operation (separated by a colon). + /// + /// An instance to the runner, so you can easily chain the operations together. + /// Incoming port, with an optional index separated by a colon. + /// Value to assing to the incoming port. + public Runner AddInput(string input, TFTensor value) + { + if (value == null) + throw new ArgumentNullException(nameof(value)); + inputs.Add(ParseOutput(input)); + inputValues.Add(value); + return this; + } + + /// + /// Adds the specified operations as the ones to be retrieved. + /// + /// An instance to the runner, so you can easily chain the operations together. + /// One or more targets. + public Runner AddTarget(params TFOperation[] targets) + { + foreach (var t in targets) + this.targets.Add(t); + return this; + } + + // Parses user strings that contain both the operation name and an index. + private TFOutput ParseOutput(string operation) + { + var p = operation.IndexOf(':'); + if (p != -1 && p != operation.Length - 1) + { + var op = operation.Substring(0, p); + if (int.TryParse(operation.Substring(p + 1), out var idx)) + { + return session.Graph[op][idx]; + } + } + return session.Graph[operation][0]; + } + + /// + /// Adds the specified operation names as the ones to be retrieved. + /// + /// An instance to the runner, so you can easily chain the operations together. + /// One or more target names. + public Runner AddTarget(params string[] targetNames) + { + foreach (var tn in targetNames) + targets.Add(session.Graph[tn]); + return this; + } + + /// + /// Makes the Run method return the index-th output of the tensor referenced by operation. + /// + /// The instance of runner, to allow chaining operations. + /// The name of the operation in the graph. + /// The index of the output in the operation. + public Runner Fetch(string operation, int index) + { + var op = session.Graph[operation]; + outputs.Add(op[index]); + return this; + } + + /// + /// Makes the Run method return the output of the tensor referenced by operation, the operation string can contain the output index. + /// + /// The instance of runner, to allow chaining operations. + /// The name of the operation in the graph, which might be a simple name, or it might be name:index, + /// where the index is the . + public Runner Fetch(string operation) + { + var op = ParseOutput(operation); + outputs.Add(op); + return this; + } + + /// + /// Makes the Run method return the output of the tensor referenced by output + /// + /// The instance of runner, to allow chaining operations. + /// The output referencing a specified tensor. + public Runner Fetch(TFOutput output) + { + outputs.Add(output); + return this; + } + + /// + /// Makes the Run method return the output of all the tensor referenced by outputs. + /// + /// The instance of runner, to allow chaining operations. + /// The outputs referencing a specified tensor. + public Runner Fetch(params TFOutput[] outputs) + { + foreach (var output in outputs) + this.outputs.Add(output); + return this; + } + + /// + /// Makes the Run method return the output of all the tensor referenced by outputs. + /// + /// The instance of runner, to allow chaining operations. + /// The output sreferencing a specified tensor. + public Runner Fetch(params string[] outputs) + { + foreach (var output in outputs) + this.outputs.Add(ParseOutput(output)); + return this; + } + + /// + /// Protocol buffer encoded block containing the metadata passed to the method. + /// + public TFBuffer RunMetadata; + + /// + /// Protocol buffer encoded block containing the run options passed to the method. + /// + public TFBuffer RunOptions; + + /// + /// Execute the graph fragments necessary to compute all requested fetches. + /// + /// One TFTensor for each call to Fetch that you made, in the order that you made them. + /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. + public TFTensor[] Run(TFStatus status = null) + { + return session.Run(inputs.ToArray(), inputValues.ToArray(), outputs.ToArray(), targets.ToArray(), RunMetadata, RunOptions, status); + } + + /// + /// Run the specified operation, by adding it implicity to the output, single return value + /// + /// The output of the operation. + /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. + /// + /// This method is a convenience method, and when you call it, it will clear any + /// calls that you might have done to Fetch() and use the specified operation to Fetch + /// instead. + /// + public TFTensor Run(TFOutput operation, TFStatus status = null) + { + outputs.Clear(); + Fetch(operation); + return Run(status)[0]; + } + + } + + /// + /// Gets a new runner, this provides a simpler API to prepare the inputs to run on a session + /// + /// The runner. + /// + /// The runner has a simple API that allows developers to call the AddTarget, AddInput, AddOutput and Fetch + /// to construct the parameters that will be passed to the TFSession.Run method. + /// + /// The Run method will return an array of TFTensor values, one for each invocation to the Fetch method. + /// + public Runner GetRunner() + { + return new Runner(this); + } + + /// + /// Executes a pipeline given the specified inputs, inputValues, outputs, targetOpers, runMetadata and runOptions. + /// A simpler API is available by calling the method which performs all the bookkeeping + /// necessary. + /// + /// An array of tensors fetched from the requested outputs. + /// Inputs nodes. + /// Input values. + /// Output nodes. + /// Target operations to execute. + /// Run metadata, a buffer containing the protocol buffer encoded value for https://github.com/tensorflow/tensorflow/blob/r1.9/tensorflow/core/protobuf/config.proto. + /// Run options, a buffer containing the protocol buffer encoded value for https://github.com/tensorflow/tensorflow/blob/r1.9/tensorflow/core/protobuf/config.proto. + /// Status buffer, if specified a status code will be left here, if not specified, a exception is raised if there is an error. + public TFTensor[] Run(TFOutput[] inputs, TFTensor[] inputValues, TFOutput[] outputs, TFOperation[] targetOpers = null, TFBuffer runMetadata = null, TFBuffer runOptions = null, TFStatus status = null) + { + if (handle == IntPtr.Zero) + ObjectDisposedException(); + if (inputs == null) + throw new ArgumentNullException(nameof(inputs)); + if (inputValues == null) + throw new ArgumentNullException(nameof(inputValues)); + if (outputs == null) + throw new ArgumentNullException(nameof(outputs)); + int iLen = inputs.Length; + if (iLen != inputValues.Length) + throw new ArgumentException("inputs and inputValues have different lengths", "inputs"); + int oLen = outputs.Length; + + // runOptions and runMetadata might be null + var cstatus = TFStatus.Setup(status); + + // Create arrays for the unmanaged versions + var ivals = new IntPtr[iLen]; + for (int i = 0; i < iLen; i++) + ivals[i] = inputValues[i].handle; + + // I believe this might not be necessary, the output values in TF_SessionRun looks like a write-only result + var ovals = new IntPtr[outputs.Length]; + IntPtr[] topers = null; + int tLen = 0; + if (targetOpers != null) + { + tLen = targetOpers.Length; + topers = new IntPtr[tLen]; + for (int i = 0; i < tLen; i++) + topers[i] = targetOpers[i].Handle; + } + + unsafe + { + TF_SessionRun(handle, runOptions == null ? null : runOptions.LLBuffer, inputs, ivals, iLen, outputs, ovals, oLen, topers, tLen, runMetadata == null ? null : runMetadata.LLBuffer, cstatus.handle); + } + cstatus.CheckMaybeRaise(status); + + // Ensure that the input tensors remain rooted, so that the GC won't collect & run finalizers between + // when they are copied to ivals and TF_SessionRun is called. + GC.KeepAlive(inputValues); + + var result = new TFTensor[oLen]; + for (int i = 0; i < oLen; i++) + { + result[i] = new TFTensor(ovals[i]); + } + return result; + } + } + + /// + /// The data type for a specific tensor. + /// + /// + /// Tensors have uniform data types, all the elements of the tensor are of this + /// type and they dictate how TensorFlow will treat the data stored. + /// + internal enum TFDataType : uint + { + /// + /// The TFDataType has not been set + /// + Unknown = 0, + + /// + /// Single precission floatint point, 32-bits (C# float) + /// + Float = 1, + /// + /// Double precission floatint point, 64-bits (C# double) + /// + Double = 2, + /// + /// 32-bit signed integers (C# int) + /// + Int32 = 3, + /// + /// 8 bit unsigned integers (C# byte) + /// + UInt8 = 4, + /// + /// 16-bit signed integers (C# short) + /// + Int16 = 5, + /// + /// 8-bit signed integers (C# sbyte) + /// + Int8 = 6, + /// + /// Binary blob + /// + String = 7, + /// + /// Single precission complex numbers (32-bit floats) + /// + Complex64 = 8, + /// + /// 32-bit float based complex numbers + /// + Complex = 8, + /// + /// 64-bit signed integers (C# long) + /// + Int64 = 9, + /// + /// 8-bit boolean (C# bool) + /// + Bool = 10, + /// + /// Quantized 8-bit signed integer + /// + QInt8 = 11, + /// + /// Quantized 8-bit unsigned integer + /// + QUInt8 = 12, + /// + /// Quantized 32-bit signed integer + /// + QInt32 = 13, + /// + /// Float32 truncated to 16 bits. Only for cast operations. + /// + BFloat16 = 14, + /// + /// Quantized 16-bit signed integer + /// + QInt16 = 15, + /// + /// Quantized 16-bit unsigned integer + /// + QUInt16 = 16, + /// + /// 16-bit unsigned integers (C# long) + /// + UInt16 = 17, + /// + /// Double precission complex numbers (32-bit floats) + /// + Complex128 = 18, + + /// + /// Half floats - 16-bit half precision floating point. + /// + Half = 19, + + /// + /// Handle to a mutable resource. + /// + Resource = 20, + + /// + /// Variant data type + /// + Variant = 21, + + /// + /// 32-bit unsigned integers + /// + UInt32 = 22, + + /// + /// 64-bit unsigned integers + /// + UInt64 = 23 + } + + /// + /// Status code for invoking a tensorflow operation. + /// + internal enum TFCode : uint + { + /// + /// Not an error; returned on success + /// + Ok = 0, + /// + /// The operation was cancelled (typically by the caller). + /// + Cancelled = 1, + /// + /// Unknown error. An example of where this error may be returned is + /// if a Status value received from another address space belongs to + /// an error-space that is not known in this address space. Also + /// errors raised by APIs that do not return enough error information + /// may be converted to this error. + /// + Unknown = 2, + + /// + /// Client specified an invalid argument. Note that this differs + /// from FailedPrecondition. InvalidArgumentindicates arguments + /// that are problematic regardless of the state of the system + /// (e.g., a malformed file name). + /// + InvalidArgument = 3, + + /// + /// Deadline expired before operation could complete. For operations + /// that change the state of the system, this error may be returned + /// even if the operation has completed successfully. For example, a + /// successful response from a server could have been delayed long + /// enough for the deadline to expire. + /// + DeadlineExceeded = 4, + + /// + /// Some requested entity (e.g., file or directory) was not found. + /// For privacy reasons, this code may be returned when the client + /// does not have the access right to the entity. + /// + NotFound = 5, + + /// + /// Some entity that we attempted to create (e.g., file or directory) already exists. + /// + AlreadyExists = 6, + + /// + /// The caller does not have permission to execute the specified + /// operation. PermissionDenied must not be used for rejections + /// caused by exhausting some resource (use ResourceExhausted + /// instead for those errors). PermissionDeniedmust not be + /// used if the caller can not be identified (use Unauthenticated + /// instead for those errors). + /// + PermissionDenied = 7, + + /// + /// The request does not have valid authentication credentials for the + /// operation. + /// + Unauthenticated = 16, + + /// + /// Some resource has been exhausted, perhaps a per-user quota, or + /// perhaps the entire file system is out of space. + /// + ResourceExhausted = 8, + + /// + /// Operation was rejected because the system is not in a state + /// required for the operation's execution. For example, directory + /// to be deleted may be non-empty, an rmdir operation is applied to + /// a non-directory, etc. + /// + /// A litmus test that may help a service implementor in deciding + /// between FailedPrecondition, Aborted, and Unavailable: + /// + /// (a) Use Unavailableif the client can retry just the failing call. + /// (b) Use Aborted if the client should retry at a higher-level + /// (e.g., restarting a read-modify-write sequence). + /// (c) Use FailedPrecondition if the client should not retry until + /// the system state has been explicitly fixed. E.g., if an "rmdir" + /// fails because the directory is non-empty, FailedPrecondition + /// should be returned since the client should not retry unless + /// they have first fixed up the directory by deleting files from it. + /// (d) Use FailedPrecondition if the client performs conditional + /// REST Get/Update/Delete on a resource and the resource on the + /// server does not match the condition. E.g., conflicting + /// read-modify-write on the same resource. + /// + FailedPrecondition = 9, + + /// + /// The operation was aborted, typically due to a concurrency issue + /// like sequencer check failures, transaction aborts, etc. + /// + /// See litmus test above for deciding between FailedPrecondition, + /// Aborted and Unavailable + /// + Aborted = 10, + + /// + /// Operation tried to iterate past the valid input range. E.g., seeking or + /// reading past end of file. + /// + /// Unlike InvalidArgument, this error indicates a problem that may + /// be fixed if the system state changes. For example, a 32-bit file + /// system will generate InvalidArgument if asked to read at an + /// offset that is not in the range [0,2^32-1], but it will generate + /// OutOfRange if asked to read from an offset past the current + /// file size. + /// + /// There is a fair bit of overlap between FailedPrecondition and + /// OutOfRange. We recommend using OutOfRane (the more specific + /// error) when it applies so that callers who are iterating through + /// a space can easily look for an OutOfRange error to detect when + /// they are done. + /// + OutOfRange = 11, + + /// + /// Operation is not implemented or not supported/enabled in this service. + /// + Unimplemented = 12, + + /// + /// Internal errors. Means some invariants expected by underlying + /// system has been broken. If you see one of these errors, + /// something is very broken. + /// + Internal = 13, + + /// + /// The service is currently unavailable. This is a most likely a + /// transient condition and may be corrected by retrying with + /// a backoff. + /// + /// See litmus test above for deciding between FailedPrecondition, + /// Aborted, and Unavailable. + /// + Unavailable = 14, + + /// + /// Unrecoverable data loss or corruption. + /// + DataLoss = 15 + } + + /// + /// Represents a specific input of an operation. + /// + [StructLayout(LayoutKind.Sequential)] + internal struct TFInput + { + /// + /// The operation that this input is for + /// + public unsafe TF_Operation Operation; + + /// + /// The index of the output within the Operation + /// + public int Index; + + // extern TF_Output TF_OperationInput (TF_Input oper_in); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern TFOutput TF_OperationInput(TFInput oper_in); + + public TFOutput GetOutput(TFInput operIn) + { + return TF_OperationInput(operIn); + } + + // extern TF_DataType TF_OperationInputType (TF_Input oper_in); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern TFDataType TF_OperationInputType(TFInput oper_in); + + public TFDataType InputType => TF_OperationInputType(this); + + } + + /// + /// Represents a specific output of an operation on a tensor. + /// + /// + /// + /// TFOutput objects represent one of the outputs of an operation in the graph + /// (TFGraph). Outputs have a data type, and eventually a shape that you can + /// retrieve by calling the method. + /// + /// + /// These can be passed as an input argument to a function for adding operations + /// to a graph, or to the TFSession's Run and GetRunner method as values to be + /// fetched. + /// + /// + [StructLayout(LayoutKind.Sequential)] + internal struct TFOutput + { + private unsafe TF_Operation LLOperation; + + /// + /// The index of the output within the operation. + /// + public int Index; + + // extern int TF_OperationOutputNumConsumers (TF_Output oper_out); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern int TF_OperationOutputNumConsumers(TFOutput oper_out); + + /// + /// Gets the number consumers. + /// + /// The number consumers. + /// + /// This number can change when new operations are added to the graph. + /// + public int NumConsumers => TF_OperationOutputNumConsumers(this); + + // extern TF_DataType TF_OperationOutputType (TF_Output oper_out); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern TFDataType TF_OperationOutputType(TFOutput oper_out); + + /// + /// Gets the type of the output. + /// + /// The type of the output. + public TFDataType OutputType => LLOperation == IntPtr.Zero ? TFDataType.Unknown : TF_OperationOutputType(this); + + /// + /// Initializes a new TFOutput instance. + /// + /// The operation to which to attach the output. + /// The index of the output within the operation, if not specified, it defaults to zero. + public TFOutput(TFOperation operation, int index = 0) + { + if (operation == null) + throw new ArgumentNullException(nameof(operation)); + LLOperation = operation.Handle; + Index = index; + } + + /// + /// Initializes a new TFOutput instance from another TFOutput + /// + /// The other TFOutput that is having its operation attached. + /// The index of the output within the operation, if not specified, it defaults to zero. + public TFOutput(TFOutput output, int index = 0) + { + if (output.LLOperation == null) + throw new ArgumentNullException("Outputs does not have a valid operation pointer"); + LLOperation = output.LLOperation; + Index = index; + } + + // extern int TF_OperationOutputConsumers (TF_Output oper_out, TF_Input *consumers, int max_consumers); + [DllImport(NativeBinding.TensorFlowLibrary)] + private static extern unsafe int TF_OperationOutputConsumers(TFOutput oper_out, TFInput* consumers, int max_consumers); + + /// + /// Get list of all current consumers of a specific output of an operation + /// + /// The output consumers. + /// + /// A concurrent modification of the graph can increase the number of consumers of + /// an operation. + /// This can return null if the TFOutput does not point to a valid object. + /// + public TFInput[] OutputConsumers + { + get + { + var result = new TFInput[NumConsumers]; + unsafe + { + fixed (TFInput* first = &result[0]) + TF_OperationOutputConsumers(this, first, result.Length); + } + return result; + } + } + + /// + /// The associated operation. + /// + /// The operation. + public TFOperation Operation => new TFOperation(null, LLOperation); + + /// + /// Returns a that represents the current . + /// + /// A that represents the current . + public override string ToString() + { + return string.Format("[{3} Index={1} Operation={2} (0x{0:X})]", (long)LLOperation, Index, Operation, OutputType); + } + } + + /// + /// Low-level: Enumeration describing the types of a metadata attribute + /// + internal enum TFAttributeType : uint + { + /// + /// The type of the attribute is a string + /// + String = 0, + + /// + /// The type of the attribute is an int. + /// + Int = 1, + + /// + /// The type of the attribute is a float + /// + Float = 2, + + /// + /// The type of the attribute is a bool. + /// + Bool = 3, + + /// + /// The type of the attribute is a type. + /// + Type = 4, + + /// + /// The type of the attribute is a tensor shape + /// + Shape = 5, + + /// + /// The type of the attribute is a tensor + /// + Tensor = 6, + + /// + /// The type of the attribute is a placeholder + /// + Placeholder = 7, + + /// + /// The type of the attribute is a function + /// + Func = 8 + } + + /// + /// Low-level: this describes the tensorflow type information for an attribute in the low-level attributes used by operations. + /// + /// + /// This is a low-level operation returned by the . + /// This is included for completeness, but is not generally used from C#, as you have access to the high-level + /// bindings in the type. + /// + [StructLayout(LayoutKind.Sequential)] + internal struct TFAttributeMetadata + { + private byte isList; + public bool IsList => isList != 0; + public long ListSize; + public TFAttributeType Type; + public long TotalSize; + + /// + /// Returns a that represents the current . + /// + /// A that represents the current . + public override string ToString() + { + return string.Format($"[TFAttributeMetadata IsList={IsList} ListSize={ListSize} Type={Type} TotalSize={TotalSize}]"); + } + } + + /// + /// Represents the shape of a tensor, it describes how many dimensions the tensor has in a given axis + /// + /// + /// + /// The shapes can be created by calling the constructor with the number of dimensions + /// in the shape. The null value is used to specify that the shape is unknown, + /// an empty array is used to create a scalar, and other values are used to specify + /// the number of dimensions. + /// + /// + /// For the Unknown case, you can use , for + /// scalars, you can use the shape. + /// + /// + /// To create a 2-element vector, use: + /// new TFShape (2) + /// + /// + /// To create a 2x3 matrix, use: + /// new TFShape (2, 3) + /// + /// + /// To create a shape with an unknown number of elements, you can pass the value + /// -1. This is typically used to indicate the shape of tensors that represent a + /// variable-sized batch of values. + /// + /// + /// To create a matrix with 4 columns and an unknown number of rows: + /// var batch = new TFShape (-1, 4) + /// + /// + internal class TFShape + { + /// + /// Represents an unknown number of dimensions in the tensor. + /// + /// The unknown. + public static TFShape Unknown => new TFShape(null); + + /// + /// This shape is used to represent scalar values. + /// + /// The scalar. + public static TFShape Scalar => new TFShape(new long[0]); + + internal long[] dims; + + /// + /// Initializes a new instance of the class. + /// + /// This is a params argument, so you can provide multiple values to it. + /// A null value means that this is an unknown shape, a single value is used to create a vector, + /// two values are used to create a 2-D matrix and so on. + /// + /// + /// + /// + public TFShape(params long[] args) + { + dims = args; + } + + /// + /// Gets the length of the specified dimension in the tensor + /// + /// The length, -1 for shapes that have an unknown dimension. + /// Dimension. + public int GetLength(int dimension) => dims == null ? -1 : dims.GetLength(dimension); + + /// + /// Number of dimensions represented by this shape. + /// + /// The number dimensions, -1 if the number of dimensions is unknown, 0 if the shape represent a scalar, 1 for a vector, 2 for a matrix and so on.. + public int NumDimensions => dims == null ? -1 : dims.Length; + + /// + /// Gets a value indicating whether all the dimensions in the are fully specified. + /// + /// true if is fully specified; otherwise, false. + public bool IsFullySpecified + { + get + { + if (dims == null) + return false; + foreach (var j in dims) + if (j == -1) + return false; + return true; + } + } + + /// + /// Returns the shape as an array + /// + /// null if the shape represents an unknown shape, otherwise an array with N elements, one per dimension, and each element can be either -1 (if the dimension size is unspecified) or the size of the dimension. + public long[] ToArray() + { + if (dims == null) + return null; + + var ret = (long[])dims.Clone(); + return ret; + } + + /// + /// Returns the shape as an array + /// + /// null if the shape represents an unknown shape, otherwise an array with N elements, one per dimension, and each element can be either -1 (if the dimension size is unspecified) or the size of the dimension. + public int[] ToIntArray() + { + if (dims == null) + return null; + + var ret = new int[dims.Length]; + for (int i = 0; i < dims.Length; i++) + { + checked + { + ret[i] = (int)dims[i]; + } + } + return ret; + } + + /// + /// Gets a value indicating whether one of the dimensions in the shape is larger than Int32.MaxValue. + /// + /// true if is long array; otherwise, false. + public bool IsLongArray + { + get + { + foreach (var l in dims) + if (l > Int32.MaxValue) + return true; + + return false; + } + } + + /// + /// Returns a that represents the current . + /// + /// A that represents the current . + public override string ToString() + { + if (dims == null) + return "unknown"; + return "[" + String.Join(", ", dims.Select(x => x == -1 ? "?" : x.ToString())) + "]"; + } + + /// + /// Gets the dimensions for the specified index. + /// + /// Index. + public long this[int idx] => dims[idx]; + + /// + /// Returns the shape as a 1-dimensional tensor with each element corresponding to the specified shape dimension. + /// + /// The tensor. + public TFTensor AsTensor() + { + return new TFTensor(ToIntArray()); + } + + /// + /// Adds a to a , yielding a shape made up of the concatenation of the first and the second shapes. + /// + /// The first to add. + /// The second to add. + /// The that is the sum of the values of left and right. + public static TFShape operator +(TFShape left, TFShape right) + { + if (left == null) + return right; + if (right == null) + return left; + + var full = new long[left.dims.Length + right.dims.Length]; + Array.Copy(left.dims, full, left.dims.Length); + Array.Copy(right.dims, 0, full, left.dims.Length, right.dims.Length); + return new TFShape(full); + } + + /// + /// Performs an implicit conversion from to . + /// + /// The shape. + /// The result of the conversion. + public static implicit operator TFTensor(TFShape shape) + { + return shape.AsTensor(); + } + } +} diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs new file mode 100644 index 0000000000..980df5e641 --- /dev/null +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs @@ -0,0 +1,54 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Runtime.InteropServices; +using Microsoft.ML.Runtime.Data; + +namespace Microsoft.ML.Transforms.TensorFlow +{ + internal partial class TensorFlowUtils + { + internal static PrimitiveType Tf2MlNetType(TFDataType type) + { + switch (type) + { + case TFDataType.Float: + return NumberType.R4; + case TFDataType.Double: + return NumberType.R8; + case TFDataType.UInt32: + return NumberType.U4; + case TFDataType.UInt64: + return NumberType.U8; + default: + throw new NotSupportedException("TensorFlow type not supported."); + } + } + + public static unsafe void FetchData(IntPtr data, T[] result) + { + var size = result.Length; + + GCHandle handle = GCHandle.Alloc(result, GCHandleType.Pinned); + IntPtr target = handle.AddrOfPinnedObject(); + + Int64 sizeInBytes = size * Marshal.SizeOf((typeof(T))); + Buffer.MemoryCopy(data.ToPointer(), target.ToPointer(), sizeInBytes, sizeInBytes); + handle.Free(); + } + + internal static bool IsTypeSupported(TFDataType tfoutput) + { + switch (tfoutput) + { + case TFDataType.Float: + case TFDataType.Double: + return true; + default: + return false; + } + } + } +} diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs new file mode 100644 index 0000000000..0aaae57598 --- /dev/null +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -0,0 +1,390 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.IO; +using System.Linq; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.CommandLine; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Transforms; +using Microsoft.ML.Transforms.TensorFlow; + +[assembly: LoadableClass(TensorFlowTransform.Summary, typeof(IDataTransform), typeof(TensorFlowTransform), + typeof(TensorFlowTransform.Arguments), typeof(SignatureDataTransform), TensorFlowTransform.UserName, TensorFlowTransform.ShortName)] + +// This is for de-serialization from a binary model file. +[assembly: LoadableClass(typeof(TensorFlowTransform.TensorFlowMapper), null, typeof(SignatureLoadRowMapper), + "", TensorFlowTransform.TensorFlowMapper.LoaderSignature)] + +[assembly: EntryPointModule(typeof(TensorFlowTransform))] + +namespace Microsoft.ML.Transforms +{ + public static class TensorFlowTransform + { + internal sealed class TensorFlowMapper : IRowMapper + { + private readonly IHost _host; + + /// + /// TensorFlow session object + /// + private readonly TFSession _session; + + private readonly string[] _inputColNames; + private readonly int[] _inputColIndices; + private readonly bool[] _isVectorInput; + private readonly TFShape[] _tfInputShapes; + private readonly TFDataType[] _tfInputTypes; + + private readonly string _outputColName; + private readonly ColumnType _outputColType; + private readonly TFDataType _tfOutputType; + + private const int BatchSize = 1; + public const string LoaderSignature = "TFMapper"; + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "TENSFLOW", + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature); + } + + public TensorFlowMapper(IHostEnvironment env, ISchema inputSchema, byte[] modelBytes, string[] inputColNames, string outputColName) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register("TensorFlowMapper"); + _host.CheckValue(inputSchema, nameof(inputSchema)); + _host.CheckNonEmpty(modelBytes, nameof(modelBytes)); + _host.CheckNonEmpty(inputColNames, nameof(inputColNames)); + _host.CheckNonEmpty(outputColName, nameof(outputColName)); + + _session = LoadTFSession(modelBytes, null); + _host.CheckValue(_session.Graph[outputColName], nameof(outputColName), "Output does not exist in the model"); + _host.Check(inputColNames.All(name => _session.Graph[name] != null), "One of the input does not exist in the model"); + + _outputColName = outputColName; + (_outputColType, _tfOutputType) = GetOutputTypes(_session.Graph, _outputColName); + (_inputColNames, _inputColIndices, _isVectorInput, _tfInputShapes, _tfInputTypes) = GetInputMetaData(_session.Graph, inputColNames, inputSchema); + } + + public static TensorFlowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema schema) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + + var numInputs = ctx.Reader.ReadInt32(); + Contracts.CheckDecode(numInputs > 0); + + string[] source = new string[numInputs]; + for (int j = 0; j < source.Length; j++) + source[j] = ctx.LoadNonEmptyString(); + + byte[] data = null; + if (!ctx.TryLoadBinaryStream("TFModel", r => data = r.ReadByteArray())) + throw env.ExceptDecode(); + + var outputColName = ctx.LoadNonEmptyString(); + + return new TensorFlowMapper(env, schema, data, source, outputColName); + } + + public void Save(ModelSaveContext ctx) + { + _host.AssertValue(ctx); + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); + + var buffer = new TFBuffer(); + _session.Graph.ToGraphDef(buffer); + + ctx.SaveBinaryStream("TFModel", w => + { + w.WriteByteArray(buffer.ToArray()); + }); + Contracts.AssertNonEmpty(_inputColNames); + ctx.Writer.Write(_inputColNames.Length); + foreach (var colName in _inputColNames) + ctx.SaveNonEmptyString(colName); + + ctx.SaveNonEmptyString(_outputColName); + } + + private TFSession LoadTFSession(byte[] modelBytes, string modelArg) + { + var graph = new TFGraph(); + try + { + graph.Import(modelBytes, ""); + } + catch (Exception ex) + { + if (!string.IsNullOrEmpty(modelArg)) + throw _host.Except($"TensorFlow exception triggered while loading model from '{modelArg}'"); +#pragma warning disable MSML_NoMessagesForLoadContext + throw _host.ExceptDecode(ex, "Tensorflow exception triggered while loading model."); +#pragma warning restore MSML_NoMessagesForLoadContext + + } + return new TFSession(graph); + } + + private ITensorValueGetter CreateTensorValueGetter(IRow input, bool isVector, int colIndex, TFShape tfShape) + { + if (isVector) + return new TensorValueGetterVec(input, colIndex, tfShape); + else + return new TensorValueGetter(input, colIndex); + } + + private ITensorValueGetter CreateTensorValueGetter(IRow input, TFDataType tfType, bool isVector, int colIndex, TFShape tfShape) + { + var type = TFTensor.TypeFromTensorType(tfType); + _host.AssertValue(type); + return Utils.MarshalInvoke(CreateTensorValueGetter, type, input, isVector, colIndex, tfShape); + } + + private ITensorValueGetter[] GetTensorValueGetters(IRow input) + { + var srcTensorGetters = new ITensorValueGetter[_inputColIndices.Length]; + for (int j = 0; j < _inputColIndices.Length; j++) + { + int colIndex = _inputColIndices[j]; + srcTensorGetters[j] = CreateTensorValueGetter(input, _tfInputTypes[j], _isVectorInput[j], colIndex, _tfInputShapes[j]); + } + return srcTensorGetters; + } + + private Delegate MakeGetter(IRow input) + { + var type = TFTensor.TypeFromTensorType(_tfOutputType); + _host.Assert(type == _outputColType.ItemType.RawType); + return Utils.MarshalInvoke(MakeGetter, type, input, _outputColType); + } + + private Delegate MakeGetter(IRow input, ColumnType columnType) + { + _host.AssertValue(input); + _host.Assert(typeof(T) == columnType.ItemType.RawType); + + var srcTensorGetters = GetTensorValueGetters(input); + + ValueGetter> valuegetter = (ref VBuffer dst) => + { + var runner = _session.GetRunner(); + for (int i = 0; i < _inputColIndices.Length; i++) + { + var inputName = _inputColNames[i]; + runner.AddInput(inputName, srcTensorGetters[i].GetTensor()); + } + + var tensors = runner.Fetch(_outputColName).Run(); + + Contracts.Assert(tensors.Length > 0); + + var values = dst.Values; + if (Utils.Size(values) < _outputColType.VectorSize) + values = new T[_outputColType.VectorSize]; + + TensorFlowUtils.FetchData(tensors[0].Data, values); + dst = new VBuffer(values.Length, values, dst.Indices); + }; + return valuegetter; + } + + public Delegate[] CreateGetters(IRow input, Func activeOutput, out Action disposer) + { + var getters = new Delegate[1]; + disposer = null; + using (var ch = _host.Start("CreateGetters")) + { + if (activeOutput(0)) + getters[0] = MakeGetter(input); + ch.Done(); + return getters; + } + } + + public Func GetDependencies(Func activeOutput) + { + return col => activeOutput(0) && _inputColIndices.Any(i => i == col); + } + + public RowMapperColumnInfo[] GetOutputColumns() + { + return new[] { new RowMapperColumnInfo(_outputColName, _outputColType, null) }; + } + + private static (ColumnType, TFDataType) GetOutputTypes(TFGraph graph, string columnName) + { + Contracts.AssertValue(graph); + Contracts.AssertNonEmpty(columnName); + Contracts.AssertValue(graph[columnName]); + + var tfoutput = new TFOutput(graph[columnName]); + var shape = graph.GetTensorShape(tfoutput); + + int[] dims = shape.ToIntArray().Skip(shape[0] == -1 ? BatchSize : 0).ToArray(); + var type = TensorFlowUtils.Tf2MlNetType(tfoutput.OutputType); + return (new VectorType(type, dims), tfoutput.OutputType); + } + + private static (string[], int[], bool[], TFShape[], TFDataType[]) GetInputMetaData(TFGraph graph, string[] source, ISchema inputSchema) + { + var tfShapes = new TFShape[source.Length]; + var tfTypes = new TFDataType[source.Length]; + var colNames = new string[source.Length]; + var inputColIndices = new int[source.Length]; + var isInputVector = new bool[source.Length]; + for (int i = 0; i < source.Length; i++) + { + colNames[i] = source[i]; + if (!inputSchema.TryGetColumnIndex(colNames[i], out inputColIndices[i])) + throw Contracts.Except($"Column '{colNames[i]}' does not exist"); + + var tfoutput = new TFOutput(graph[colNames[i]]); + if (!TensorFlowUtils.IsTypeSupported(tfoutput.OutputType)) + throw Contracts.Except($"Input type '{tfoutput.OutputType}' of input column '{colNames[i]}' is not supported in TensorFlow"); + + tfShapes[i] = graph.GetTensorShape(tfoutput); + var type = inputSchema.GetColumnType(inputColIndices[i]); + var shape = tfShapes[i].ToIntArray().Skip(tfShapes[i][0] == -1 ? BatchSize : 0); + if (type.AsVector.DimCount == 1) + { + int valCount = shape.Aggregate((x, y) => x * y); + if (type.ValueCount != valCount) + throw Contracts.Except($"Input shape mismatch: Input '{colNames[i]}' has shape {tfShapes[i].ToString()}, but input data is of length {valCount}."); + } + else if (shape.Select((dim, j) => dim != type.AsVector.GetDim(j)).Any(b => b)) + throw Contracts.Except($"Input shape mismatch: Input '{colNames[i]}' has shape {tfShapes[i].ToString()}, but input data is {type.AsVector.ToString()}."); + + isInputVector[i] = type.IsVector; + + tfTypes[i] = tfoutput.OutputType; + + var l = new long[tfShapes[i].NumDimensions]; + for (int ishape = 0; ishape < tfShapes[i].NumDimensions; ishape++) + { + l[ishape] = tfShapes[i][ishape] == -1 ? BatchSize : tfShapes[i][ishape]; + } + tfShapes[i] = new TFShape(l); + } + return (colNames, inputColIndices, isInputVector, tfShapes, tfTypes); + } + } + + public sealed class Arguments : TransformInputBase + { + + [Argument(ArgumentType.Required, HelpText = "This is the frozen protobuf model file. Please see https://www.tensorflow.org/mobile/prepare_models for more details.", ShortName = "ModelDir", SortOrder = 0)] + public string ModelFile; + + [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "The names of the model inputs", ShortName = "inputs", SortOrder = 1)] + public string[] InputColumns; + + [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "The name of the output", ShortName = "output", SortOrder = 2)] + public string OutputColumn; + } + + public const string Summary = "Transforms the data using the TensorFlow model."; + public const string UserName = "TensorFlowTransform"; + public const string ShortName = "TFTransform"; + private const string RegistrationName = "TensorFlowTransform"; + + /// + /// Convenience constructor for public facing API. + /// + /// Host Environment. + /// Input . This is the output from previous transform or loader. + /// This is the frozen TensorFlow model file. https://www.tensorflow.org/mobile/prepare_models + /// Name of the output column. Keep it same as in the TensorFlow model. + /// Name of the input column(s). Keep it same as in the TensorFlow model. + public static IDataTransform Create(IHostEnvironment env, IDataView input, string modelFile, string name, params string[] source) + { + return Create(env, new Arguments() { InputColumns = source, OutputColumn = name, ModelFile = modelFile }, input); + } + + /// + public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + var host = env.Register(RegistrationName); + host.CheckValue(args, nameof(args)); + host.CheckUserArg(Utils.Size(args.InputColumns) > 0, nameof(args.InputColumns)); + for (int i = 0; i < args.InputColumns.Length; i++) + host.CheckNonWhiteSpace(args.InputColumns[i], nameof(args.InputColumns)); + host.CheckNonWhiteSpace(args.ModelFile, nameof(args.ModelFile)); + host.CheckUserArg(File.Exists(args.ModelFile), nameof(args.ModelFile)); + + var modelBytes = File.ReadAllBytes(args.ModelFile); + var mapper = new TensorFlowMapper(host, input.Schema, modelBytes, args.InputColumns, args.OutputColumn); + return new RowToRowMapperTransform(host, input, mapper); + } + + private interface ITensorValueGetter + { + TFTensor GetTensor(); + } + + private class TensorValueGetter : ITensorValueGetter + { + private readonly ValueGetter _srcgetter; + + public TensorValueGetter(IRow input, int colIndex) + { + _srcgetter = input.GetGetter(colIndex); + } + public TFTensor GetTensor() + { + var scalar = default(T); + _srcgetter(ref scalar); + return TFTensor.CreateScalar(scalar); + } + } + + private class TensorValueGetterVec : ITensorValueGetter + { + private readonly ValueGetter> _srcgetter; + private readonly TFShape _tfShape; + private VBuffer _vBuffer; + private VBuffer _vBufferDense; + public TensorValueGetterVec(IRow input, int colIndex, TFShape tfShape) + { + _srcgetter = input.GetGetter>(colIndex); + _tfShape = tfShape; + _vBuffer = default; + _vBufferDense = default; + } + public TFTensor GetTensor() + { + _srcgetter(ref _vBuffer); + _vBuffer.CopyToDense(ref _vBufferDense); + return TFTensor.Create(_vBufferDense.Values, _tfShape); + } + } + + [TlcModule.EntryPoint(Name = "Transforms.TensorFlowScorer", Desc = Summary, UserName = UserName, ShortName = ShortName)] + public static CommonOutputs.TransformOutput TensorFlowScorer(IHostEnvironment env, Arguments input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(input, nameof(input)); + + var h = EntryPointUtils.CheckArgsAndCreateHost(env, "TensorFlow", input); + var view = Create(h, input, input.Data); + return new CommonOutputs.TransformOutput() + { + Model = new TransformModel(h, view, input.Data), + OutputData = view + }; + } + } +} diff --git a/src/Microsoft.ML.TensorFlow/doc.xml b/src/Microsoft.ML.TensorFlow/doc.xml new file mode 100644 index 0000000000..29d7ca2844 --- /dev/null +++ b/src/Microsoft.ML.TensorFlow/doc.xml @@ -0,0 +1,74 @@ + + + + + + + Extracts hidden layers' values from a pre-trained Tensorflow model. + + + The TensorflowTransform extracts the specified output from the operation computed on the graph (given the input(s)) using a pre-trained Tensorflow model. + The transform takes as input the Tensorflow model together with the names of the inputs to the model and name of the operation for which output values will be extracted from the model. + + The TensorflowTransform has following assumptions regarding the input, output and processing of data. + + + The transform currently accepts the frozen TensorFlow model file as input. + + The transform supports scoring only one example at a time. + The name of input column(s) should match the name of input(s) in Tensorflow model. + The name of the output column should match one of the operations in the Tensorflow graph. + Currently, float and double are the only acceptable data types for input/output. + + Upon success, the transform will introduce a new column in based on the name of the output column specified. + + + + + + + + pipeline.Add(new TextLoader(dataFile).CreateFrom<MNISTData>(useHeader: false)); + pipeline.Add(new ColumnCopier(("NumericImageVec", "Input"); + pipeline.Add(new TensorFlowScorer() + { + ModelFile = model_location; + InputColumns = new []{ "Input" }; + OutputColumn = "Output" + } + + + + + var pipeline = new LearningPipeline(seed: 1); + pipeline.Add(new TextLoader(dataFile).CreateFrom<CifarData>(useHeader: false)); + pipeline.Add(new ImageLoader(("ImagePath", "ImageReal")) + { + ImageFolder = imageFolder + }); + + pipeline.Add(new ImageResizer(("ImageReal", "ImageCropped")) + { + ImageHeight = imageHeight, + ImageWidth = imageWidth, + Resizing = ImageResizerTransformResizingKind.IsoCrop + }); + + pipeline.Add(new ImagePixelExtractor(("ImageCropped", "Input")) + { + UseAlpha = false, + InterleaveArgb = true + }); + + pipeline.Add(new TensorFlowScorer() + { + ModelFile = model_location, + InputColumns = new[] { "Input" }, + OutputColumn = "Output" + }); + + + + + + \ No newline at end of file diff --git a/src/Microsoft.ML.Transforms/Microsoft.ML.Transforms.csproj b/src/Microsoft.ML.Transforms/Microsoft.ML.Transforms.csproj index 8aa272922c..5da86f946e 100644 --- a/src/Microsoft.ML.Transforms/Microsoft.ML.Transforms.csproj +++ b/src/Microsoft.ML.Transforms/Microsoft.ML.Transforms.csproj @@ -1,4 +1,4 @@ - + netstandard2.0 diff --git a/src/Microsoft.ML/CSharpApi.cs b/src/Microsoft.ML/CSharpApi.cs index 78d19e3608..66718150d1 100644 --- a/src/Microsoft.ML/CSharpApi.cs +++ b/src/Microsoft.ML/CSharpApi.cs @@ -1534,6 +1534,18 @@ public void Add(Microsoft.ML.Transforms.SupervisedBinNormalizer input, Microsoft _jsonNodes.Add(Serialize("Transforms.SupervisedBinNormalizer", input, output)); } + public Microsoft.ML.Transforms.TensorFlowScorer.Output Add(Microsoft.ML.Transforms.TensorFlowScorer input) + { + var output = new Microsoft.ML.Transforms.TensorFlowScorer.Output(); + Add(input, output); + return output; + } + + public void Add(Microsoft.ML.Transforms.TensorFlowScorer input, Microsoft.ML.Transforms.TensorFlowScorer.Output output) + { + _jsonNodes.Add(Serialize("Transforms.TensorFlowScorer", input, output)); + } + public Microsoft.ML.Transforms.TextFeaturizer.Output Add(Microsoft.ML.Transforms.TextFeaturizer input) { var output = new Microsoft.ML.Transforms.TextFeaturizer.Output(); @@ -15904,6 +15916,81 @@ public SupervisedBinNormalizerPipelineStep(Output output) } } + namespace Transforms + { + + /// + /// Transforms the data using the TensorFlow model. + /// + public sealed partial class TensorFlowScorer : Microsoft.ML.Runtime.EntryPoints.CommonInputs.ITransformInput, Microsoft.ML.ILearningPipelineItem + { + + + /// + /// This is the frozen protobuf model file. Please see https://www.tensorflow.org/mobile/prepare_models for more details. + /// + public string ModelFile { get; set; } + + /// + /// The names of the model inputs + /// + public string[] InputColumns { get; set; } + + /// + /// The name of the output + /// + public string OutputColumn { get; set; } + + /// + /// Input dataset + /// + public Var Data { get; set; } = new Var(); + + + public sealed class Output : Microsoft.ML.Runtime.EntryPoints.CommonOutputs.ITransformOutput + { + /// + /// Transformed dataset + /// + public Var OutputData { get; set; } = new Var(); + + /// + /// Transform model + /// + public Var Model { get; set; } = new Var(); + + } + public Var GetInputData() => Data; + + public ILearningPipelineStep ApplyStep(ILearningPipelineStep previousStep, Experiment experiment) + { + if (previousStep != null) + { + if (!(previousStep is ILearningPipelineDataStep dataStep)) + { + throw new InvalidOperationException($"{ nameof(TensorFlowScorer)} only supports an { nameof(ILearningPipelineDataStep)} as an input."); + } + + Data = dataStep.Data; + } + Output output = experiment.Add(this); + return new TensorFlowScorerPipelineStep(output); + } + + private class TensorFlowScorerPipelineStep : ILearningPipelineDataStep + { + public TensorFlowScorerPipelineStep(Output output) + { + Data = output.OutputData; + Model = output.Model; + } + + public Var Data { get; } + public Var Model { get; } + } + } + } + namespace Transforms { public enum TextTransformLanguage diff --git a/src/Native/build.proj b/src/Native/build.proj index ab338ac855..5e9d7989cb 100644 --- a/src/Native/build.proj +++ b/src/Native/build.proj @@ -3,7 +3,6 @@ true - x64 True @@ -26,21 +25,14 @@ - $(BaseOutputPath)$(TargetArchitecture).$(Configuration)\Native - - lib - .dll - .so - .dylib - - .pdb - .so.dbg - .dylib.dwarf - - win - linux - osx - $(PackageRid)-$(TargetArchitecture) + lib + .dll + .so + .dylib + + .pdb + .so.dbg + .dylib.dwarf + + + + netstandard2.0 + + + + + + + + + + + + + + + + + + + + + + + + + <_downloadFiles Include="@(TensorFlowArchive);@(AdditionalDownloadFile)" Url="%(Identity)" DestinationFile="%(DownloadFile)" /> + + + + + + + + + + + <_filesToCheckSum Include="@(TensorFlowArchive->'%(DownloadFile)')" DestinationPath="%(DownloadShaFile)" /> + + + + + + + + + + + + $([System.IO.File]::ReadAllText('%(LocalShaFile)')) + $([System.IO.File]::ReadAllText('%(DownloadShaFile)')) + + + + + + + + + + + + + + + + + + + + <_fileFromArchive Include="%(TensorFlowArchive.FilesFromArchive)" ExtractDirectory="%(TensorFlowArchive.ExtractDirectory)" Runtime="%(TensorFlowArchive.Runtime)" /> + <_fileFromArchive DestinationFile="%(FileName)%(Extension)"/> + + <_fileFromArchive Condition="'%(Runtime)' == 'osx-x64' AND '%(Extension)' == '.so'" DestinationFile="%(FileName).dylib" /> + <_fileFromArchive PackagePath="runtimes\%(_fileFromArchive.Runtime)\native\%(_fileFromArchive.DestinationFile)" /> + + + <_fileFromArchive Condition="'%(DestinationFile)' == 'LICENSE'" PackagePath="THIRD_PARTY_NOTICES.txt" Runtime="" /> + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/Redist/Microsoft.ML.TensorFlow.Redist/libtensorflow-cpu-darwin-x86_64-1.10.0.tar.gz.sha b/src/Redist/Microsoft.ML.TensorFlow.Redist/libtensorflow-cpu-darwin-x86_64-1.10.0.tar.gz.sha new file mode 100644 index 0000000000..da8b53866b --- /dev/null +++ b/src/Redist/Microsoft.ML.TensorFlow.Redist/libtensorflow-cpu-darwin-x86_64-1.10.0.tar.gz.sha @@ -0,0 +1 @@ +77218EC4DA96A73B15B8AA5637C9F21B389510A9FAF4DCF06DF5B81A5403015C6BA3EEE29BD8BA5B0694F40C671D8E6722D554C4F93F95C33F29AB491C70263C \ No newline at end of file diff --git a/src/Redist/Microsoft.ML.TensorFlow.Redist/libtensorflow-cpu-linux-x86_64-1.10.0.tar.gz.sha b/src/Redist/Microsoft.ML.TensorFlow.Redist/libtensorflow-cpu-linux-x86_64-1.10.0.tar.gz.sha new file mode 100644 index 0000000000..6b865984ec --- /dev/null +++ b/src/Redist/Microsoft.ML.TensorFlow.Redist/libtensorflow-cpu-linux-x86_64-1.10.0.tar.gz.sha @@ -0,0 +1 @@ +B9E9CD95BC6A28297ACAB0D684FBBFAFF1F9AE893432AC2D208120D767101AC20E2C55BC79E59DBE6E5BD9EC802026694960FA12137BB303061C5A21B62BD29E \ No newline at end of file diff --git a/src/Redist/Microsoft.ML.TensorFlow.Redist/libtensorflow-cpu-windows-x86_64-1.10.0.zip.sha b/src/Redist/Microsoft.ML.TensorFlow.Redist/libtensorflow-cpu-windows-x86_64-1.10.0.zip.sha new file mode 100644 index 0000000000..92ce0db9fb --- /dev/null +++ b/src/Redist/Microsoft.ML.TensorFlow.Redist/libtensorflow-cpu-windows-x86_64-1.10.0.zip.sha @@ -0,0 +1 @@ +66F3A9522917076038AE9CCA11FE805DD516C60B3A3E156B78C2E4BD0E3E5785A9D0380C5E06411473EF14A72B72FD93F954AA3496A12D1FAF0FA3393970E700 \ No newline at end of file diff --git a/src/Redist/build.proj b/src/Redist/build.proj new file mode 100644 index 0000000000..6891516cc0 --- /dev/null +++ b/src/Redist/build.proj @@ -0,0 +1,6 @@ + + + + + + diff --git a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv index 700c55b28a..14cab2b984 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv +++ b/test/BaselineOutput/Common/EntryPoints/core_ep-list.tsv @@ -124,6 +124,7 @@ Transforms.Scorer Turn the predictor model into a transform model Microsoft.ML.R Transforms.Segregator Un-groups vector columns into sequences of rows, inverse of Group transform Microsoft.ML.Runtime.Data.GroupingOperations Ungroup Microsoft.ML.Runtime.Data.UngroupTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.SentimentAnalyzer Uses a pretrained sentiment model to score input strings Microsoft.ML.Runtime.Transforms.TextAnalytics AnalyzeSentiment Microsoft.ML.Runtime.TextAnalytics.SentimentAnalyzingTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.SupervisedBinNormalizer Similar to BinNormalizer, but calculates bins based on correlation with the label column, not equi-density. The new value is bin_number / number_of_bins. Microsoft.ML.Runtime.Data.Normalize SupervisedBin Microsoft.ML.Runtime.Data.NormalizeTransform+SupervisedBinArguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput +Transforms.TensorFlowScorer Transforms the data using the TensorFlow model. Microsoft.ML.Transforms.TensorFlowTransform TensorFlowScorer Microsoft.ML.Transforms.TensorFlowTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.TextFeaturizer A transform that turns a collection of text documents into numerical feature vectors. The feature vectors are normalized counts of (word and/or character) ngrams in a given tokenized text. Microsoft.ML.Runtime.Transforms.TextAnalytics TextTransform Microsoft.ML.Runtime.Data.TextTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.TextToKeyConverter Converts input values (words, numbers, etc.) to index in a dictionary. Microsoft.ML.Runtime.Data.Categorical TextToKey Microsoft.ML.Runtime.Data.TermTransform+Arguments Microsoft.ML.Runtime.EntryPoints.CommonOutputs+TransformOutput Transforms.TrainTestDatasetSplitter Split the dataset into train and test sets Microsoft.ML.Runtime.EntryPoints.TrainTestSplit Split Microsoft.ML.Runtime.EntryPoints.TrainTestSplit+Input Microsoft.ML.Runtime.EntryPoints.TrainTestSplit+Output diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 3c8814d60c..99db39b419 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -21879,6 +21879,76 @@ "ITransformOutput" ] }, + { + "Name": "Transforms.TensorFlowScorer", + "Desc": "Transforms the data using the TensorFlow model.", + "FriendlyName": "TensorFlowTransform", + "ShortName": "TFTransform", + "Inputs": [ + { + "Name": "ModelFile", + "Type": "String", + "Desc": "This is the frozen protobuf model file. Please see https://www.tensorflow.org/mobile/prepare_models for more details.", + "Aliases": [ + "ModelDir" + ], + "Required": true, + "SortOrder": 0.0, + "IsNullable": false + }, + { + "Name": "InputColumns", + "Type": { + "Kind": "Array", + "ItemType": "String" + }, + "Desc": "The names of the model inputs", + "Aliases": [ + "inputs" + ], + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "Data", + "Type": "DataView", + "Desc": "Input dataset", + "Required": true, + "SortOrder": 1.0, + "IsNullable": false + }, + { + "Name": "OutputColumn", + "Type": "String", + "Desc": "The name of the output", + "Aliases": [ + "output" + ], + "Required": true, + "SortOrder": 2.0, + "IsNullable": false + } + ], + "Outputs": [ + { + "Name": "OutputData", + "Type": "DataView", + "Desc": "Transformed dataset" + }, + { + "Name": "Model", + "Type": "TransformModel", + "Desc": "Transform model" + } + ], + "InputKind": [ + "ITransformInput" + ], + "OutputKind": [ + "ITransformOutput" + ] + }, { "Name": "Transforms.TextFeaturizer", "Desc": "A transform that turns a collection of text documents into numerical feature vectors. The feature vectors are normalized counts of (word and/or character) ngrams in a given tokenized text.", diff --git a/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj b/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj index be16385901..736c75c241 100644 --- a/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj +++ b/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj @@ -16,6 +16,7 @@ + @@ -26,6 +27,8 @@ + + \ No newline at end of file diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs index 34922a4ab1..24e53757a7 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs @@ -957,5 +957,42 @@ public void TestOvaMacroWithUncalibratedLearner() } } } + + [Fact] + public void TestTensorFlowEntryPoint() + { + var dataPath = GetDataPath("Train-Tiny-28x28.txt"); + using (var env = new TlcEnvironment(42)) + { + var experiment = env.CreateExperiment(); + + var importInput = new ML.Data.TextLoader(dataPath); + importInput.Arguments.Column = new TextLoaderColumn[] + { + new TextLoaderColumn { Name = "Label", Source = new[] { new TextLoaderRange(0) } }, + new TextLoaderColumn { Name = "Placeholder", Source = new[] { new TextLoaderRange(1, 784) } } + }; + var importOutput = experiment.Add(importInput); + + var tfTransformInput = new ML.Transforms.TensorFlowScorer + { + Data = importOutput.Data, + InputColumns = new[] { "Placeholder" }, + OutputColumn = "Softmax", + ModelFile = "mnist_model/frozen_saved_model.pb" + }; + var tfTransformOutput = experiment.Add(tfTransformInput); + + experiment.Compile(); + experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false)); + experiment.Run(); + var data = experiment.GetOutput(tfTransformOutput.OutputData); + + var schema = data.Schema; + Assert.Equal(3, schema.ColumnCount); + Assert.Equal("Softmax", schema.GetColumnName(2)); + Assert.Equal(10, schema.GetColumnType(2).VectorSize); + } + } } } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index 3fa16b11c9..5f4a579914 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -3732,5 +3732,18 @@ public void EntryPointWordEmbeddings() } } } + + [Fact] + public void EntryPointTensorFlowTransform() + { + TestEntryPointPipelineRoutine(GetDataPath("Train-Tiny-28x28.txt"), "col=Label:R4:0 col=Placeholder:R4:1-784", + new[] { "Transforms.TensorFlowScorer" }, + new[] + { + @"'InputColumns': [ 'Placeholder' ], + 'ModelFile': 'mnist_model/frozen_saved_model.pb', + 'OutputColumn': 'Softmax'" + }); + } } } \ No newline at end of file diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj index d8d5cd3d5d..39301dc429 100644 --- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj +++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj @@ -15,6 +15,7 @@ + @@ -26,9 +27,14 @@ + + - \ No newline at end of file + + + + diff --git a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs new file mode 100644 index 0000000000..d7d2e9a2de --- /dev/null +++ b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs @@ -0,0 +1,83 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.ImageAnalytics; +using Microsoft.ML.Runtime.LightGBM; +using Microsoft.ML.Transforms; +using System.Collections.Generic; +using System.IO; +using Xunit; + +namespace Microsoft.ML.Scenarios +{ + public partial class ScenariosTests + { + [Fact(Skip = "Disabled due to this bug https://github.com/dotnet/machinelearning/issues/770")] + public void TensorFlowTransformCifarLearningPipelineTest() + { + var imageHeight = 32; + var imageWidth = 32; + var model_location = "cifar_model/frozen_model.pb"; + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + + var pipeline = new LearningPipeline(seed: 1); + pipeline.Add(new Microsoft.ML.Data.TextLoader(dataFile).CreateFrom(useHeader: false)); + pipeline.Add(new ImageLoader(("ImagePath", "ImageReal")) + { + ImageFolder = imageFolder + }); + + pipeline.Add(new ImageResizer(("ImageReal", "ImageCropped")) + { + ImageHeight = imageHeight, + ImageWidth = imageWidth, + Resizing = ImageResizerTransformResizingKind.IsoCrop + }); + + pipeline.Add(new ImagePixelExtractor(("ImageCropped", "Input")) + { + UseAlpha = false, + InterleaveArgb = true + }); + + pipeline.Add(new TensorFlowScorer() + { + ModelFile = model_location, + InputColumns = new[] { "Input" }, + OutputColumn = "Output" + }); + + using (var environment = new TlcEnvironment()) + { + IDataView trans = pipeline.Execute(environment); + Assert.NotNull(trans); + + trans.Schema.TryGetColumnIndex("Output", out int output); + using (var cursor = trans.GetRowCursor(col => col == output)) + { + var buffer = default(VBuffer); + var getter = cursor.GetGetter>(output); + while (cursor.MoveNext()) + { + getter(ref buffer); + Assert.Equal(10, buffer.Length); + } + } + } + } + } + + public class CifarData + { + [Column("0")] + public string ImagePath; + + [Column("1")] + public string Name; + } +} diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs new file mode 100644 index 0000000000..ed450cd8c0 --- /dev/null +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -0,0 +1,227 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.ImageAnalytics; +using Microsoft.ML.Runtime.LightGBM; +using Microsoft.ML.Transforms; +using System.Collections.Generic; +using System.IO; +using Xunit; + +namespace Microsoft.ML.Scenarios +{ + public partial class ScenariosTests + { + private class TestData + { + [VectorType(4)] + public float[] a; + [VectorType(4)] + public float[] b; + } + + [Fact] + public void TensorFlowTransformMatrixMultiplicationTest() + { + var model_location = "model_matmul/frozen_saved_model.pb"; + using (var env = new TlcEnvironment(seed: 1, conc: 1)) + { + // Pipeline + var loader = ComponentCreation.CreateDataView(env, + new List(new TestData[] { new TestData() { a = new[] { 1.0f, 2.0f, + 3.0f, 4.0f }, + b = new[] { 1.0f, 2.0f, + 3.0f, 4.0f } }, + new TestData() { a = new[] { 2.0f, 2.0f, + 2.0f, 2.0f }, + b = new[] { 3.0f, 3.0f, + 3.0f, 3.0f } } })); + + var trans = TensorFlowTransform.Create(env, loader, model_location, "c", "a", "b"); + + using (var cursor = trans.GetRowCursor(a => true)) + { + var cgetter = cursor.GetGetter>(2); + Assert.True(cursor.MoveNext()); + VBuffer c = default; + cgetter(ref c); + + Assert.Equal(1.0 * 1.0 + 2.0 * 3.0, c.Values[0]); + Assert.Equal(1.0 * 2.0 + 2.0 * 4.0, c.Values[1]); + Assert.Equal(3.0 * 1.0 + 4.0 * 3.0, c.Values[2]); + Assert.Equal(3.0 * 2.0 + 4.0 * 4.0, c.Values[3]); + + Assert.True(cursor.MoveNext()); + c = default; + cgetter(ref c); + + Assert.Equal(2.0 * 3.0 + 2.0 * 3.0, c.Values[0]); + Assert.Equal(2.0 * 3.0 + 2.0 * 3.0, c.Values[1]); + Assert.Equal(2.0 * 3.0 + 2.0 * 3.0, c.Values[2]); + Assert.Equal(2.0 * 3.0 + 2.0 * 3.0, c.Values[3]); + + Assert.False(cursor.MoveNext()); + + } + } + } + + [Fact] + public void TensorFlowTransformMNISTConvTest() + { + var model_location = "mnist_model/frozen_saved_model.pb"; + using (var env = new TlcEnvironment(seed: 1, conc: 1)) + { + var dataPath = GetDataPath("Train-Tiny-28x28.txt"); + var testDataPath = GetDataPath("MNIST.Test.tiny.txt"); + + // Pipeline + var loader = TextLoader.ReadFile(env, + new TextLoader.Arguments() + { + Separator = "tab", + HasHeader = true, + Column = new[] + { + new TextLoader.Column() + { + Name = "Label", + Source = new [] { new TextLoader.Range() { Min=0, Max=0} }, + Type = DataKind.Num + }, + + new TextLoader.Column() + { + Name = "Placeholder", + Source = new [] { new TextLoader.Range() { Min=1, Max=784} }, + Type = DataKind.Num + } + } + }, new MultiFileSource(dataPath)); + + IDataView trans = TensorFlowTransform.Create(env, loader, model_location, "Softmax", "Placeholder"); + trans = new ConcatTransform(env, trans, "reshape_input", "Placeholder"); + trans = TensorFlowTransform.Create(env, trans, model_location, "dense/Relu", "reshape_input"); + trans = new ConcatTransform(env, trans, "Features", "Softmax", "dense/Relu"); + + var trainer = new LightGbmMulticlassTrainer(env, new LightGbmArguments()); + + var cached = new CacheDataView(env, trans, prefetch: null); + var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features"); + var pred = trainer.Train(trainRoles); + + // Get scorer and evaluate the predictions from test data + IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath); + var metrics = Evaluate(env, testDataScorer); + + Assert.Equal(0.99, metrics.AccuracyMicro, 2); + Assert.Equal(0.99, metrics.AccuracyMicro, 2); + + // Create prediction engine and test predictions + var model = env.CreatePredictionEngine(testDataScorer); + + var sample1 = new MNISTData() + { + Placeholder = new float[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 18, 18, 18, 126, 136, 175, 26, + 166, 255, 247, 127, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 30, 36, 94, 154, 170, 253, 253, 253, 253, 253, + 225, 172, 253, 242, 195, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 49, 238, 253, 253, 253, 253, 253, 253, 253, + 253, 251, 93, 82, 82, 56, 39, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 219, 253, 253, 253, 253, 253, 198, + 182, 247, 241, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 80, 156, 107, 253, 253, 205, 11, 0, + 43, 154, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 1, 154, 253, 90, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 139, 253, 190, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 190, 253, 70, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35, 241, 225, 160, 108, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 81, 240, 253, 253, 119, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 45, 186, 253, 253, 150, 27, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 93, 252, 253, 187, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 249, 253, 249, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 46, 130, 183, 253, 253, 207, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 39, 148, 229, 253, 253, 253, 250, 182, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24, 114, 221, 253, 253, 253, 253, 201, 78, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 23, 66, 213, 253, 253, 253, 253, 198, 81, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 171, 219, 253, 253, 253, 253, 195, 80, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 55, 172, 226, 253, 253, 253, 253, 244, 133, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 136, 253, 253, 253, 212, 135, 132, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } + }; + + var prediction = model.Predict(sample1); + + float max = -1; + int maxIndex = -1; + for(int i=0;i max) + { + max = prediction.PredictedLabels[i]; + maxIndex = i; + } + } + + Assert.Equal(5, maxIndex); + } + } + + public class MNISTData + { + [Column("1")] + public float Label; + + [VectorType(784)] + public float[] Placeholder; + } + + public class MNISTPrediction + { + [ColumnName("Score")] + public float[] PredictedLabels; + } + + [Fact] + public void TensorFlowTransformCifar() + { + var model_location = "cifar_model/frozen_model.pb"; + + using (var env = new TlcEnvironment()) + { + var imageHeight = 32; + var imageWidth = 32; + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); + var images = new ImageLoaderTransform(env, new ImageLoaderTransform.Arguments() + { + Column = new ImageLoaderTransform.Column[1] + { + new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" } + }, + ImageFolder = imageFolder + }, data); + var cropped = new ImageResizerTransform(env, new ImageResizerTransform.Arguments() + { + Column = new ImageResizerTransform.Column[1]{ + new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop} + } + }, images); + + var pixels = new ImagePixelExtractorTransform(env, new ImagePixelExtractorTransform.Arguments() + { + Column = new ImagePixelExtractorTransform.Column[1]{ + new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "Input", UseAlpha=false, InterleaveArgb=true} + } + }, cropped); + + + IDataView trans = TensorFlowTransform.Create(env, pixels, model_location, "Output", "Input"); + + trans.Schema.TryGetColumnIndex("Output", out int output); + using (var cursor = trans.GetRowCursor(col => col == output)) + { + var buffer = default(VBuffer); + var getter = cursor.GetGetter>(output); + while (cursor.MoveNext()) + { + getter(ref buffer); + Assert.Equal(10, buffer.Length); + } + } + } + } + } +}