Skip to content

TensorFlowMapper transform for scoring Tensorflow models in ML.NET #704

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 61 commits into from
Aug 30, 2018
Merged
Show file tree
Hide file tree
Changes from 59 commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
f6baa5b
Merge pull request #1 from dotnet/master
abgoswam Jul 18, 2018
7d0ea81
Merge pull request #2 from dotnet/master
abgoswam Jul 30, 2018
bad9cd2
Merge pull request #3 from dotnet/master
abgoswam Aug 20, 2018
6b76960
creating dummy file to test permissions. will remove
abgoswam Aug 21, 2018
085cf6c
test
yaeldMS Aug 21, 2018
4175209
TensorFlow scoring, from Zeeshan A.'s branch, with some additional ch…
yaeldMS Aug 21, 2018
3e7d118
Merge pull request #4 from dotnet/master
abgoswam Aug 22, 2018
30dcdc5
creating dummy file to test permissions. will remove
abgoswam Aug 21, 2018
b80750b
test
yaeldMS Aug 21, 2018
58c703a
TensorFlow scoring, from Zeeshan A.'s branch, with some additional ch…
yaeldMS Aug 21, 2018
032dc48
Merge branch 'agoswami/tensorflow' of https://github.com/abgoswam/mac…
abgoswam Aug 22, 2018
1f83474
taking care of review comments; build fixes
abgoswam Aug 22, 2018
bf7b3a7
simple change intended to trigger fresh builds to repro OSX-Release b…
abgoswam Aug 22, 2018
3c0fc04
Prevent input tensors from being GC'ed before TF_SessionRun is called
ericstj Aug 24, 2018
c6ecc72
Remove Tensorflow models from tests data
ericstj Aug 24, 2018
54de166
Add entry point
yaeldMS Aug 24, 2018
2ed4575
Merge remote-tracking branch 'upstream/master'
yaeldMS Aug 24, 2018
fb5b607
Merge branch 'master' into agoswami/tensorflow
yaeldMS Aug 24, 2018
55aca30
Fix manifest and generated C# API.
yaeldMS Aug 24, 2018
3f2e3ae
Create a redist package for tensorflow
ericstj Aug 21, 2018
7b53f44
Make TF redist project a normal MSBuild project
ericstj Aug 23, 2018
fba6c91
Add TF License file to package
ericstj Aug 23, 2018
e22b819
Don't use fullpaths when un-tar'ing
ericstj Aug 23, 2018
ada8787
Add some logging to TF redist project
ericstj Aug 23, 2018
2444743
Fix casing of TF redist proj
ericstj Aug 24, 2018
de4f6ea
Change tests to use redistributed TensorFlow
ericstj Aug 24, 2018
35500ba
Fix mac / linux tensorflow redist
ericstj Aug 24, 2018
b9ddacc
Only include LICENSE if it exists
ericstj Aug 24, 2018
a0caf2f
unit test to verify TF transform works with ML.NET image transforms
abgoswam Aug 26, 2018
7a5da64
Merge branch 'agoswami/tensorflow' of https://github.com/abgoswam/mac…
abgoswam Aug 26, 2018
cdb5892
Factor TensorflowTransform into its own assembly/package
ericstj Aug 27, 2018
88f54c2
Update Tensorflow to TensorFlow
yaeldMS Aug 27, 2018
b69af6c
Fix manifest and C# API. Also, add entry point unit test.
yaeldMS Aug 27, 2018
8938aa4
update test case for image transforms; still skipping test for now ti…
abgoswam Aug 27, 2018
136dca4
fix conflicts
abgoswam Aug 27, 2018
c1db655
Use IDataView instead of var in unit test.
yaeldMS Aug 27, 2018
919a408
fix have from Tensorflow to TensorFlow
abgoswam Aug 27, 2018
182e036
Merge branch 'agoswami/tensorflow' of https://github.com/abgoswam/mac…
abgoswam Aug 27, 2018
7f6a495
Merge remote-tracking branch 'upstream/master' into agoswami/tensorflow
yaeldMS Aug 27, 2018
413be27
Fix unit test to use new TextLoader APIs.
yaeldMS Aug 27, 2018
0002083
Add unit test using the pipeline API.
yaeldMS Aug 27, 2018
dd3e00a
Validate input dimensions.
yaeldMS Aug 27, 2018
b6d8c74
Added XML doc for TensorflowTransform.
zeahmed Aug 27, 2018
1830655
Remove extra dimension for batch size in output.
yaeldMS Aug 28, 2018
e3a8e50
Merge branch 'agoswami/tensorflow' of https://github.com/abgoswam/mac…
yaeldMS Aug 28, 2018
7fa60e1
Merge remote-tracking branch 'upstream/master' into agoswami/tensorflow
yaeldMS Aug 28, 2018
8f9a8df
Fix input/output validation.
yaeldMS Aug 28, 2018
4d40f94
Introduced BatchSize as constant to replace 1 everywhere.
zeahmed Aug 28, 2018
e23c7be
enabling unit test of TensorFlowTransform working with ML.NET Image* …
abgoswam Aug 28, 2018
f61486b
Extended XML docs with detailed information.
zeahmed Aug 28, 2018
7b03215
Merge branch 'agoswami/tensorflow' of https://github.com/abgoswam/mac…
zeahmed Aug 28, 2018
878874b
Ensure we include actual TF license in package
ericstj Aug 29, 2018
afdf0cb
Change input validation to validate all dimensions in case of multi d…
yaeldMS Aug 29, 2018
acef318
Corrected typos in the documentation.
zeahmed Aug 29, 2018
3fea9da
Merge branch 'agoswami/tensorflow' of https://github.com/abgoswam/mac…
zeahmed Aug 29, 2018
39c3533
Fix LICENSE inclusion in package.
ericstj Aug 29, 2018
145db2d
Add symbols package and fix package reference to redist
ericstj Aug 29, 2018
e991fa9
Address pull request comments.
yaeldMS Aug 29, 2018
228fb24
Merge branch 'agoswami/tensorflow' of https://github.com/abgoswam/mac…
yaeldMS Aug 29, 2018
5a849c4
Give more details in input dimension mismatch error message.
yaeldMS Aug 29, 2018
08da76d
Added a test for LearningPipelineAPI and updated the doc.xml
zeahmed Aug 30, 2018
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions Microsoft.ML.sln
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,14 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CpuMath.UnitTe
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CpuMath.UnitTests.netcoreapp", "test\Microsoft.ML.CpuMath.UnitTests.netcoreapp\Microsoft.ML.CpuMath.UnitTests.netcoreapp.csproj", "{5F81A2A4-73AD-494C-B387-07D605EC8826}"
EndProject

Project("{F2A71F9B-5D33-465A-A702-920D77279786}") = "Microsoft.ML.FSharp.Tests", "test\Microsoft.ML.FSharp.Tests\Microsoft.ML.FSharp.Tests.fsproj", "{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}"
Project("{6EC3EE1D-3C4E-46DD-8F32-0CC8E7565705}") = "Microsoft.ML.FSharp.Tests", "test\Microsoft.ML.FSharp.Tests\Microsoft.ML.FSharp.Tests.fsproj", "{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.ImageAnalytics", "src\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj", "{00E38F77-1E61-4CDF-8F97-1417D4E85053}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.HalLearners", "src\Microsoft.ML.HalLearners\Microsoft.ML.HalLearners.csproj", "{A7222F41-1CF0-47D9-B80C-B4D77B027A61}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.TensorFlow", "src\Microsoft.ML.TensorFlow\Microsoft.ML.TensorFlow.csproj", "{570A0B8A-5463-44D2-8521-54C0CA4CACA9}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
Expand Down Expand Up @@ -390,6 +391,14 @@ Global
{A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Release|Any CPU.Build.0 = Release|Any CPU
{A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
{A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
{570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Debug|Any CPU.Build.0 = Debug|Any CPU
{570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
{570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
{570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Release|Any CPU.ActiveCfg = Release|Any CPU
{570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Release|Any CPU.Build.0 = Release|Any CPU
{570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
{570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -426,14 +435,15 @@ Global
{001F3B4E-FBE4-4001-AFD2-A6A989CD1C25} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{DCF46B79-1FDB-4DBA-A263-D3D64E3AAA27} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{BF66A305-DF10-47E4-8D81-42049B149D2B} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
{B4E55B2D-2A92-46E7-B72F-E76D6FD83440} = {7F13E156-3EBA-4021-84A5-CD56BA72F99E}
{3E4ABF07-7970-4BE6-B45B-A13D3C397545} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{7333EDEF-4144-405C-A5EC-6F42201857D8} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{A0E562A9-0E6D-470D-B180-6EB44BA84D60} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{5F81A2A4-73AD-494C-B387-07D605EC8826} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{B4E55B2D-2A92-46E7-B72F-E76D6FD83440} = {7F13E156-3EBA-4021-84A5-CD56BA72F99E}
{3E4ABF07-7970-4BE6-B45B-A13D3C397545} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
{00E38F77-1E61-4CDF-8F97-1417D4E85053} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{A7222F41-1CF0-47D9-B80C-B4D77B027A61} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{570A0B8A-5463-44D2-8521-54C0CA4CACA9} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
Expand Down
11 changes: 10 additions & 1 deletion build.proj
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
<TraversalBuildDependsOn>
CreateOrUpdateCurrentVersionFile;
RestoreProjects;
BuildRedist;
BuildNative;
$(TraversalBuildDependsOn);
DownloadExternalTestFiles;
Expand All @@ -44,9 +45,17 @@
Properties="MSBuildWarningsAsMessages=NU1503" />
</Target>

<Target Name="BuildRedist"
Condition="'$(SkipRedistBuild)' != 'true'"
DependsOnTargets="RestoreProjects">
<Message Importance="High" Text="Building redist components..." />
<MSBuild Projects="src/Redist/build.proj"
Targets="Build" />
</Target>

<Target Name="BuildNative"
Condition="'$(SkipNativeBuild)' != 'true'"
DependsOnTargets="RestoreProjects">
DependsOnTargets="RestoreProjects;BuildRedist">
<Message Importance="High" Text="Building native components..." />
<MSBuild Projects="src/Native/build.proj"
Targets="Build" />
Expand Down
1 change: 1 addition & 0 deletions build/Dependencies.props
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@
<MlNetMklDepsPackageVersion>0.0.0.5</MlNetMklDepsPackageVersion>
<SystemDrawingCommonPackageVersion>4.5.0</SystemDrawingCommonPackageVersion>
<BenchmarkDotNetVersion>0.11.0</BenchmarkDotNetVersion>
<TensorFlowVersion>1.10.0</TensorFlowVersion>
</PropertyGroup>
</Project>
5 changes: 4 additions & 1 deletion build/sign.proj
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,10 @@

<!-- If we are not signing nuget packages we default to sign binaries -->
<ItemGroup Condition="'$(SignNugetPackages)' != 'true'">
<FilesToSign Include="$(OutDir)**/*.dll">
<!-- Don't sign tensorflow since we don't build it. -->
<ExcludeFilesToSign Include="$(OutDir)**/tensorflow.dll" />

<FilesToSign Include="$(OutDir)**/*.dll" Exclude="@(ExcludeFilesToSign)">
<Authenticode>Microsoft</Authenticode>
</FilesToSign>
</ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
<Project Sdk="Microsoft.NET.Sdk" DefaultTargets="Pack">

<PropertyGroup>
<Authors>The TensorFlow Authors</Authors>
<TargetFramework>netstandard2.0</TargetFramework>
<PackageDescription>$(MSBuildProjectName) contains the TensorFlow C library version $(TensorFlowVersion) redistributed as a NuGet package.</PackageDescription>
<PackageLicenseUrl>https://github.com/tensorflow/tensorflow/blob/master/LICENSE</PackageLicenseUrl>
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
<Copyright>Copyright 2018 The TensorFlow Authors. All rights reserved.</Copyright>
<PackageProjectUrl>https://www.tensorflow.org</PackageProjectUrl>
<PackageReleaseNotes>https://github.com/tensorflow/tensorflow/releases/tag/v$(TensorFlowVersion)</PackageReleaseNotes>
<PackageTags>$(PackageTags) TensorFlow</PackageTags>
<!-- TODO: consider PackageIconUrl -->
</PropertyGroup>

<ItemGroup>
<Content Include="..\common\CommonPackage.props" Pack="true" PackagePath="build\netstandard2.0\$(MSBuildProjectName).props" />
<Content Include="$(PackageAssetsPath)$(PackageIdFolderName)\LICENSE.txt" Pack="true" PackagePath=".\" />
<Content Include="$(PackageAssetsPath)$(PackageIdFolderName)\THIRD_PARTY_NOTICES.txt" Pack="true" PackagePath=".\" />
</ItemGroup>
</Project>
12 changes: 12 additions & 0 deletions pkg/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.nupkgproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
<Project Sdk="Microsoft.NET.Sdk" DefaultTargets="Pack">
Copy link
Member

@eerhardt eerhardt Aug 29, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ericstj - We should also have a ".symbols" pkgproj file. That way a symbols package gets produced and is uploaded to the symbols server for the managed assemblies in this package. See the other folders for an example. #Resolved


<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<PackageDescription>Microsoft.ML.TensorFlow contains ML.NET integration of TensorFlow.</PackageDescription>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\Microsoft.ML.TensorFlow.Redist\Microsoft.ML.TensorFlow.Redist.nupkgproj" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
<Project DefaultTargets="Pack">

<Import Project="Microsoft.ML.TensorFlow.nupkgproj" />

</Project>
9 changes: 9 additions & 0 deletions src/Directory.Build.props
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@
<WarningsNotAsErrors>$(WarningsNotAsErrors);1591</WarningsNotAsErrors>

<CodeAnalysisRuleSet>$(MSBuildThisFileDirectory)\Source.ruleset</CodeAnalysisRuleSet>

<TargetArchitecture Condition="'$(TargetArchitecture)' == ''">x64</TargetArchitecture>

<NativeAssetsBuiltPath>$(BaseOutputPath)$(TargetArchitecture).$(Configuration)\Native</NativeAssetsBuiltPath>

<PackageRid Condition="'$(OS)' == 'Windows_NT'">win</PackageRid>
<PackageRid Condition="'$(OS)' != 'Windows_NT'">linux</PackageRid>
<PackageRid Condition="$([MSBuild]::IsOSPlatform('osx'))">osx</PackageRid>
<PackageRid>$(PackageRid)-$(TargetArchitecture)</PackageRid>
</PropertyGroup>

<ItemGroup>
Expand Down
31 changes: 31 additions & 0 deletions src/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.csproj
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>netstandard2.0</TargetFramework>
<IncludeInPackage>Microsoft.ML.TensorFlow</IncludeInPackage>
<DefineConstants>CORECLR</DefineConstants>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

<ItemGroup>
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
</ItemGroup>

<ItemGroup>
<Compile Update="TensorFlow\TensorGeneric.cs">
<DesignTime>True</DesignTime>
<AutoGen>True</AutoGen>
<DependentUpon>TensorGeneric.tt</DependentUpon>
</Compile>
<None Update="TensorFlow\TensorGeneric.tt">
<Generator>TextTemplatingFileGenerator</Generator>
<LastGenOutput>TensorGeneric.cs</LastGenOutput>
</None>
</ItemGroup>

<ItemGroup>
<Service Include="{508349b6-6b84-4df5-91f0-309beebad82d}" />
</ItemGroup>

</Project>
211 changes: 211 additions & 0 deletions src/Microsoft.ML.TensorFlow/TensorFlow/Buffer.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System;
using System.Runtime.InteropServices;
using System.Text;
using size_t = System.UIntPtr;

#pragma warning disable MSML_GeneralName
#pragma warning disable MSML_ParameterLocalVarName

namespace Microsoft.ML.Transforms.TensorFlow
{
/// <summary>
/// This attribute can be applied to callback functions that will be invoked
/// from unmanaged code to managed code.
/// </summary>
/// <remarks>
/// <code>
/// [TensorFlow.MonoPInvokeCallback (typeof (BufferReleaseFunc))]
/// internal static void MyFreeFunc (IntPtr data, IntPtr length){..}
/// </code>
/// </remarks>
internal sealed class MonoPInvokeCallbackAttribute : Attribute
{
/// <summary>
/// Use this constructor to annotate the type of the callback function that
/// will be invoked from unmanaged code.
/// </summary>
/// <param name="t">T.</param>
public MonoPInvokeCallbackAttribute(Type t) { }
}

[StructLayout(LayoutKind.Sequential)]
internal struct LLBuffer
{
internal IntPtr data;
internal size_t length;
internal IntPtr data_deallocator;
}

/// <summary>
/// Holds a block of data, suitable to pass, or retrieve from TensorFlow.
/// </summary>
/// <remarks>
/// <para>
/// Use the TFBuffer to blobs of data into TensorFlow, or to retrieve blocks
/// of data out of TensorFlow.
/// </para>
/// <para>
/// There are two constructors to wrap existing data, one to wrap blocks that are
/// pointed to by an IntPtr and one that takes a byte array that we want to wrap.
/// </para>
/// <para>
/// The empty constructor can be used to create a new TFBuffer that can be populated
/// by the TensorFlow library and returned to user code.
/// </para>
/// <para>
/// Typically, the data consists of a serialized protocol buffer, but other data
/// may also be held in a buffer.
/// </para>
/// </remarks>
// TODO: the string ctor
// TODO: perhaps we should have an implicit byte [] conversion that just calls ToArray?
internal class TFBuffer : TFDisposable
{
// extern TF_Buffer * TF_NewBufferFromString (const void *proto, size_t proto_len);
[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern unsafe LLBuffer* TF_NewBufferFromString(IntPtr proto, IntPtr proto_len);

// extern TF_Buffer * TF_NewBuffer ();
[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern unsafe LLBuffer* TF_NewBuffer();

internal TFBuffer(IntPtr handle) : base(handle) { }

/// <summary>
/// Initializes a new instance of the <see cref="T:TensorFlow.TFBuffer"/> class.
/// </summary>
public unsafe TFBuffer() : base((IntPtr)TF_NewBuffer())
{
}

/// <summary>
/// Signature of the method that is invoked to release the data.
/// </summary>
/// <remarks>
/// Methods of this signature are invoked with the data pointer and the
/// lenght pointer when then TFBuffer no longer needs to hold on to the
/// data. If you are using this on platforms with static compilation
/// like iOS, you need to annotate your callback with the MonoPInvokeCallbackAttribute,
/// like this:
///
/// <code>
/// [TensorFlow.MonoPInvokeCallback (typeof (BufferReleaseFunc))]
/// internal static void MyFreeFunc (IntPtr data, IntPtr length){..}
/// </code>
/// </remarks>
public delegate void BufferReleaseFunc(IntPtr data, IntPtr lenght);

/// <summary>
/// Initializes a new instance of the <see cref="T:TensorFlow.TFBuffer"/> by wrapping the unmanaged resource pointed by the buffer.
/// </summary>
/// <param name="buffer">Pointer to the data that will be wrapped.</param>
/// <param name="size">The size of the buffer to wrap.</param>
/// <param name="release">Optional, if not null, this method will be invoked to release the block.</param>
/// <remarks>
/// This constructor wraps the buffer as a the data to be held by the <see cref="T:TensorFlow.TFBuffer"/>,
/// if the release parameter is null, then you must ensure that the data is not released before the TFBuffer
/// is no longer in use. If the value is not null, the provided method will be invoked to release
/// the data when the TFBuffer is disposed, or the contents of the buffer replaced.
/// </remarks>
public unsafe TFBuffer(IntPtr buffer, long size, BufferReleaseFunc release) : base((IntPtr)TF_NewBuffer())
{
LLBuffer* buf = (LLBuffer*)handle;
buf->data = buffer;
buf->length = (size_t)size;
if (release == null)
buf->data_deallocator = IntPtr.Zero;
else
buf->data_deallocator = Marshal.GetFunctionPointerForDelegate(release);
}

[MonoPInvokeCallback(typeof(BufferReleaseFunc))]
internal static void FreeBlock(IntPtr data, IntPtr length)
{
Marshal.FreeHGlobal(data);
}

internal static IntPtr FreeBufferFunc;
internal static BufferReleaseFunc FreeBlockDelegate;

static TFBuffer()
{
FreeBlockDelegate = FreeBlock;
FreeBufferFunc = Marshal.GetFunctionPointerForDelegate<BufferReleaseFunc>(FreeBlockDelegate);
}

/// <summary>
/// Initializes a new instance of the <see cref="T:TensorFlow.TFBuffer"/> by making a copy of the provided byte array.
/// </summary>
/// <param name="buffer">Buffer of data that will be wrapped.</param>
/// <remarks>
/// This constructor makes a copy of the data into an unmanaged buffer,
/// so the byte array is not pinned.
/// </remarks>
public TFBuffer(byte[] buffer) : this(buffer, 0, buffer.Length) { }

/// <summary>
/// Initializes a new instance of the <see cref="T:TensorFlow.TFBuffer"/> by making a copy of the provided byte array.
/// </summary>
/// <param name="buffer">Buffer of data that will be wrapped.</param>
/// <param name="start">Starting offset into the buffer to wrap.</param>
/// <param name="count">Number of bytes from the buffer to keep.</param>
/// <remarks>
/// This constructor makes a copy of the data into an unmanaged buffer,
/// so the byte array is not pinned.
/// </remarks>
public TFBuffer(byte[] buffer, int start, int count) : this()
{
if (start < 0 || start >= buffer.Length)
throw new ArgumentException("start");
if (count < 0 || count > buffer.Length - start)
throw new ArgumentException("count");
unsafe
{
LLBuffer* buf = LLBuffer;
buf->data = Marshal.AllocHGlobal(count);
Marshal.Copy(buffer, start, buf->data, count);
buf->length = (size_t)count;
buf->data_deallocator = FreeBufferFunc;
}
}

internal unsafe LLBuffer* LLBuffer => (LLBuffer*)handle;

// extern void TF_DeleteBuffer (TF_Buffer *);
[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern unsafe void TF_DeleteBuffer(LLBuffer* buffer);

internal override void NativeDispose(IntPtr handle)
{
unsafe { TF_DeleteBuffer((LLBuffer*)handle); }
}

// extern TF_Buffer TF_GetBuffer (TF_Buffer *buffer);
[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern unsafe LLBuffer TF_GetBuffer(LLBuffer* buffer);

/// <summary>
/// Returns a byte array representing the data wrapped by this buffer.
/// </summary>
/// <returns>The array.</returns>
public byte[] ToArray()
{
if (handle == IntPtr.Zero)
return null;

unsafe
{
var lb = (LLBuffer*)handle;

var result = new byte[(int)lb->length];
Marshal.Copy(lb->data, result, 0, (int)lb->length);

return result;
}
}
}
}
Loading