Skip to content

Commit 5ef7a08

Browse files
abgoswamyaeldMS
authored andcommitted
TensorFlowMapper transform for scoring Tensorflow models in ML.NET (#704)
* creating dummy file to test permissions. will remove * test * TensorFlow scoring, from Zeeshan A.'s branch, with some additional changes. * creating dummy file to test permissions. will remove * test * TensorFlow scoring, from Zeeshan A.'s branch, with some additional changes. * taking care of review comments; build fixes * simple change intended to trigger fresh builds to repro OSX-Release build failure * Prevent input tensors from being GC'ed before TF_SessionRun is called * Remove Tensorflow models from tests data Instead bring these from a nuget package created in another repo. For now this is https://github.com/ericstj/machinelearning-testdata-temp Soon it will be https://github.com/dotnet/machinelearning-testdata * Add entry point * Fix manifest and generated C# API. * Create a redist package for tensorflow Create a nuget package that redistributes the TensorFlow C-API. This is needed because TensorFlow doesn't ship an official NuGet package. This is a straight up repack of the bits published on tensorflow.org. I made sure to apply the TensorFlow license to this package and not sign it with our authenticate certificates. * Make TF redist project a normal MSBuild project Remove the use of the SDK targets, and define our own build and clean. * Add TF License file to package * Don't use fullpaths when un-tar'ing Tar on windows was failing when msbuild passed it a full path. Workaround by using relative paths and running where we extract to. * Add some logging to TF redist project * Fix casing of TF redist proj * Change tests to use redistributed TensorFlow Also modify TF binding code to use `tensorflow` in its DLLImports. The runtime will still add the approriate prefix/extension on linux/mac. * Fix mac / linux tensorflow redist I was missing the libtensorflow_framework dependency which caused mac and linux tests to fail. After fixing that, mac still failed due to inability to load libtensorflow_framework.so. We have to rename these to .dylib to satisfy the CORECLR dllimport convention which broke the internal rpath in libtensorflow which pointed at @rpath/libtensorflow_framework.so. Fix this by rewriting the renamed libtensorflow.dylib on mac. Since this operation can only be done on mac, I had to change the build of the redist project to only build the bits appropriate for the building platform. To make this work correctly in the official build I had to make sure these platform specific builds happen when the native build happens. * Only include LICENSE if it exists LICENSE is pulled from the Windows package, so it isn't available on mac or linux. * unit test to verify TF transform works with ML.NET image transforms * Factor TensorflowTransform into its own assembly/package * Update Tensorflow to TensorFlow * Fix manifest and C# API. Also, add entry point unit test. * update test case for image transforms; still skipping test for now till we figure out why false positive passes * Use IDataView instead of var in unit test. * fix have from Tensorflow to TensorFlow * Fix unit test to use new TextLoader APIs. * Add unit test using the pipeline API. * Validate input dimensions. * Added XML doc for TensorflowTransform. * Remove extra dimension for batch size in output. * Fix input/output validation. * Introduced BatchSize as constant to replace 1 everywhere. * enabling unit test of TensorFlowTransform working with ML.NET Image* Transforms * Extended XML docs with detailed information. * Ensure we include actual TF license in package The TF zip/tarballs were missing the actual TF license. Download this and include it in the package. Rename the file from the zip/tarballs as THIRD_PARTY_NOTICES.txt as that represents its content. * Change input validation to validate all dimensions in case of multi dimensional input. * Corrected typos in the documentation. * Fix LICENSE inclusion in package. We didn't define ExtractDirectory on the item and instead had a full path as identity. This worked on linux/osx since it prepended a ""/ to the path, which was tolerated by the file system (an extra leading slash). On Windows this doesn't work or course. Fix by appending to the item that doesn't assume files came from an archive. * Add symbols package and fix package reference to redist * Address pull request comments. * Give more details in input dimension mismatch error message. * Added a test for LearningPipelineAPI and updated the doc.xml
1 parent 9258be2 commit 5ef7a08

33 files changed

+4918
-24
lines changed

Microsoft.ML.sln

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,14 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CpuMath.UnitTe
103103
EndProject
104104
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CpuMath.UnitTests.netcoreapp", "test\Microsoft.ML.CpuMath.UnitTests.netcoreapp\Microsoft.ML.CpuMath.UnitTests.netcoreapp.csproj", "{5F81A2A4-73AD-494C-B387-07D605EC8826}"
105105
EndProject
106-
107-
Project("{F2A71F9B-5D33-465A-A702-920D77279786}") = "Microsoft.ML.FSharp.Tests", "test\Microsoft.ML.FSharp.Tests\Microsoft.ML.FSharp.Tests.fsproj", "{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}"
106+
Project("{6EC3EE1D-3C4E-46DD-8F32-0CC8E7565705}") = "Microsoft.ML.FSharp.Tests", "test\Microsoft.ML.FSharp.Tests\Microsoft.ML.FSharp.Tests.fsproj", "{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3}"
108107
EndProject
109108
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.ImageAnalytics", "src\Microsoft.ML.ImageAnalytics\Microsoft.ML.ImageAnalytics.csproj", "{00E38F77-1E61-4CDF-8F97-1417D4E85053}"
110109
EndProject
111110
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.HalLearners", "src\Microsoft.ML.HalLearners\Microsoft.ML.HalLearners.csproj", "{A7222F41-1CF0-47D9-B80C-B4D77B027A61}"
112111
EndProject
112+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.TensorFlow", "src\Microsoft.ML.TensorFlow\Microsoft.ML.TensorFlow.csproj", "{570A0B8A-5463-44D2-8521-54C0CA4CACA9}"
113+
EndProject
113114
Global
114115
GlobalSection(SolutionConfigurationPlatforms) = preSolution
115116
Debug|Any CPU = Debug|Any CPU
@@ -390,6 +391,14 @@ Global
390391
{A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Release|Any CPU.Build.0 = Release|Any CPU
391392
{A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
392393
{A7222F41-1CF0-47D9-B80C-B4D77B027A61}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
394+
{570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
395+
{570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Debug|Any CPU.Build.0 = Debug|Any CPU
396+
{570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
397+
{570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
398+
{570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Release|Any CPU.ActiveCfg = Release|Any CPU
399+
{570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Release|Any CPU.Build.0 = Release|Any CPU
400+
{570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
401+
{570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
393402
EndGlobalSection
394403
GlobalSection(SolutionProperties) = preSolution
395404
HideSolutionNode = FALSE
@@ -426,14 +435,15 @@ Global
426435
{001F3B4E-FBE4-4001-AFD2-A6A989CD1C25} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
427436
{DCF46B79-1FDB-4DBA-A263-D3D64E3AAA27} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
428437
{BF66A305-DF10-47E4-8D81-42049B149D2B} = {D3D38B03-B557-484D-8348-8BADEE4DF592}
438+
{B4E55B2D-2A92-46E7-B72F-E76D6FD83440} = {7F13E156-3EBA-4021-84A5-CD56BA72F99E}
439+
{3E4ABF07-7970-4BE6-B45B-A13D3C397545} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
429440
{7333EDEF-4144-405C-A5EC-6F42201857D8} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
430441
{A0E562A9-0E6D-470D-B180-6EB44BA84D60} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
431442
{5F81A2A4-73AD-494C-B387-07D605EC8826} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
432-
{B4E55B2D-2A92-46E7-B72F-E76D6FD83440} = {7F13E156-3EBA-4021-84A5-CD56BA72F99E}
433-
{3E4ABF07-7970-4BE6-B45B-A13D3C397545} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
434443
{802233D6-8CC0-46AD-9F23-FEE1E9AED9B3} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
435444
{00E38F77-1E61-4CDF-8F97-1417D4E85053} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
436445
{A7222F41-1CF0-47D9-B80C-B4D77B027A61} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
446+
{570A0B8A-5463-44D2-8521-54C0CA4CACA9} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
437447
EndGlobalSection
438448
GlobalSection(ExtensibilityGlobals) = postSolution
439449
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}

build.proj

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
<TraversalBuildDependsOn>
3232
CreateOrUpdateCurrentVersionFile;
3333
RestoreProjects;
34+
BuildRedist;
3435
BuildNative;
3536
$(TraversalBuildDependsOn);
3637
DownloadExternalTestFiles;
@@ -44,9 +45,17 @@
4445
Properties="MSBuildWarningsAsMessages=NU1503" />
4546
</Target>
4647

48+
<Target Name="BuildRedist"
49+
Condition="'$(SkipRedistBuild)' != 'true'"
50+
DependsOnTargets="RestoreProjects">
51+
<Message Importance="High" Text="Building redist components..." />
52+
<MSBuild Projects="src/Redist/build.proj"
53+
Targets="Build" />
54+
</Target>
55+
4756
<Target Name="BuildNative"
4857
Condition="'$(SkipNativeBuild)' != 'true'"
49-
DependsOnTargets="RestoreProjects">
58+
DependsOnTargets="RestoreProjects;BuildRedist">
5059
<Message Importance="High" Text="Building native components..." />
5160
<MSBuild Projects="src/Native/build.proj"
5261
Targets="Build" />

build/Dependencies.props

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@
1111
<MlNetMklDepsPackageVersion>0.0.0.5</MlNetMklDepsPackageVersion>
1212
<SystemDrawingCommonPackageVersion>4.5.0</SystemDrawingCommonPackageVersion>
1313
<BenchmarkDotNetVersion>0.11.0</BenchmarkDotNetVersion>
14+
<TensorFlowVersion>1.10.0</TensorFlowVersion>
1415
</PropertyGroup>
1516
</Project>

build/sign.proj

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,10 @@
3030

3131
<!-- If we are not signing nuget packages we default to sign binaries -->
3232
<ItemGroup Condition="'$(SignNugetPackages)' != 'true'">
33-
<FilesToSign Include="$(OutDir)**/*.dll">
33+
<!-- Don't sign tensorflow since we don't build it. -->
34+
<ExcludeFilesToSign Include="$(OutDir)**/tensorflow.dll" />
35+
36+
<FilesToSign Include="$(OutDir)**/*.dll" Exclude="@(ExcludeFilesToSign)">
3437
<Authenticode>Microsoft</Authenticode>
3538
</FilesToSign>
3639
</ItemGroup>
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
<Project Sdk="Microsoft.NET.Sdk" DefaultTargets="Pack">
2+
3+
<PropertyGroup>
4+
<Authors>The TensorFlow Authors</Authors>
5+
<TargetFramework>netstandard2.0</TargetFramework>
6+
<PackageDescription>$(MSBuildProjectName) contains the TensorFlow C library version $(TensorFlowVersion) redistributed as a NuGet package.</PackageDescription>
7+
<PackageLicenseUrl>https://github.com/tensorflow/tensorflow/blob/master/LICENSE</PackageLicenseUrl>
8+
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
9+
<Copyright>Copyright 2018 The TensorFlow Authors. All rights reserved.</Copyright>
10+
<PackageProjectUrl>https://www.tensorflow.org</PackageProjectUrl>
11+
<PackageReleaseNotes>https://github.com/tensorflow/tensorflow/releases/tag/v$(TensorFlowVersion)</PackageReleaseNotes>
12+
<PackageTags>$(PackageTags) TensorFlow</PackageTags>
13+
<!-- TODO: consider PackageIconUrl -->
14+
</PropertyGroup>
15+
16+
<ItemGroup>
17+
<Content Include="..\common\CommonPackage.props" Pack="true" PackagePath="build\netstandard2.0\$(MSBuildProjectName).props" />
18+
<Content Include="$(PackageAssetsPath)$(PackageIdFolderName)\LICENSE.txt" Pack="true" PackagePath=".\" />
19+
<Content Include="$(PackageAssetsPath)$(PackageIdFolderName)\THIRD_PARTY_NOTICES.txt" Pack="true" PackagePath=".\" />
20+
</ItemGroup>
21+
</Project>
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
<Project Sdk="Microsoft.NET.Sdk" DefaultTargets="Pack">
2+
3+
<PropertyGroup>
4+
<TargetFramework>netstandard2.0</TargetFramework>
5+
<PackageDescription>Microsoft.ML.TensorFlow contains ML.NET integration of TensorFlow.</PackageDescription>
6+
</PropertyGroup>
7+
8+
<ItemGroup>
9+
<ProjectReference Include="..\Microsoft.ML.TensorFlow.Redist\Microsoft.ML.TensorFlow.Redist.nupkgproj" />
10+
</ItemGroup>
11+
12+
</Project>
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
<Project DefaultTargets="Pack">
2+
3+
<Import Project="Microsoft.ML.TensorFlow.nupkgproj" />
4+
5+
</Project>

src/Directory.Build.props

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,15 @@
1212
<WarningsNotAsErrors>$(WarningsNotAsErrors);1591</WarningsNotAsErrors>
1313

1414
<CodeAnalysisRuleSet>$(MSBuildThisFileDirectory)\Source.ruleset</CodeAnalysisRuleSet>
15+
16+
<TargetArchitecture Condition="'$(TargetArchitecture)' == ''">x64</TargetArchitecture>
17+
18+
<NativeAssetsBuiltPath>$(BaseOutputPath)$(TargetArchitecture).$(Configuration)\Native</NativeAssetsBuiltPath>
19+
20+
<PackageRid Condition="'$(OS)' == 'Windows_NT'">win</PackageRid>
21+
<PackageRid Condition="'$(OS)' != 'Windows_NT'">linux</PackageRid>
22+
<PackageRid Condition="$([MSBuild]::IsOSPlatform('osx'))">osx</PackageRid>
23+
<PackageRid>$(PackageRid)-$(TargetArchitecture)</PackageRid>
1524
</PropertyGroup>
1625

1726
<ItemGroup>
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
3+
<PropertyGroup>
4+
<TargetFramework>netstandard2.0</TargetFramework>
5+
<IncludeInPackage>Microsoft.ML.TensorFlow</IncludeInPackage>
6+
<DefineConstants>CORECLR</DefineConstants>
7+
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
8+
</PropertyGroup>
9+
10+
<ItemGroup>
11+
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
12+
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
13+
</ItemGroup>
14+
15+
<ItemGroup>
16+
<Compile Update="TensorFlow\TensorGeneric.cs">
17+
<DesignTime>True</DesignTime>
18+
<AutoGen>True</AutoGen>
19+
<DependentUpon>TensorGeneric.tt</DependentUpon>
20+
</Compile>
21+
<None Update="TensorFlow\TensorGeneric.tt">
22+
<Generator>TextTemplatingFileGenerator</Generator>
23+
<LastGenOutput>TensorGeneric.cs</LastGenOutput>
24+
</None>
25+
</ItemGroup>
26+
27+
<ItemGroup>
28+
<Service Include="{508349b6-6b84-4df5-91f0-309beebad82d}" />
29+
</ItemGroup>
30+
31+
</Project>
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using System;
6+
using System.Runtime.InteropServices;
7+
using System.Text;
8+
using size_t = System.UIntPtr;
9+
10+
#pragma warning disable MSML_GeneralName
11+
#pragma warning disable MSML_ParameterLocalVarName
12+
13+
namespace Microsoft.ML.Transforms.TensorFlow
14+
{
15+
/// <summary>
16+
/// This attribute can be applied to callback functions that will be invoked
17+
/// from unmanaged code to managed code.
18+
/// </summary>
19+
/// <remarks>
20+
/// <code>
21+
/// [TensorFlow.MonoPInvokeCallback (typeof (BufferReleaseFunc))]
22+
/// internal static void MyFreeFunc (IntPtr data, IntPtr length){..}
23+
/// </code>
24+
/// </remarks>
25+
internal sealed class MonoPInvokeCallbackAttribute : Attribute
26+
{
27+
/// <summary>
28+
/// Use this constructor to annotate the type of the callback function that
29+
/// will be invoked from unmanaged code.
30+
/// </summary>
31+
/// <param name="t">T.</param>
32+
public MonoPInvokeCallbackAttribute(Type t) { }
33+
}
34+
35+
[StructLayout(LayoutKind.Sequential)]
36+
internal struct LLBuffer
37+
{
38+
internal IntPtr data;
39+
internal size_t length;
40+
internal IntPtr data_deallocator;
41+
}
42+
43+
/// <summary>
44+
/// Holds a block of data, suitable to pass, or retrieve from TensorFlow.
45+
/// </summary>
46+
/// <remarks>
47+
/// <para>
48+
/// Use the TFBuffer to blobs of data into TensorFlow, or to retrieve blocks
49+
/// of data out of TensorFlow.
50+
/// </para>
51+
/// <para>
52+
/// There are two constructors to wrap existing data, one to wrap blocks that are
53+
/// pointed to by an IntPtr and one that takes a byte array that we want to wrap.
54+
/// </para>
55+
/// <para>
56+
/// The empty constructor can be used to create a new TFBuffer that can be populated
57+
/// by the TensorFlow library and returned to user code.
58+
/// </para>
59+
/// <para>
60+
/// Typically, the data consists of a serialized protocol buffer, but other data
61+
/// may also be held in a buffer.
62+
/// </para>
63+
/// </remarks>
64+
// TODO: the string ctor
65+
// TODO: perhaps we should have an implicit byte [] conversion that just calls ToArray?
66+
internal class TFBuffer : TFDisposable
67+
{
68+
// extern TF_Buffer * TF_NewBufferFromString (const void *proto, size_t proto_len);
69+
[DllImport(NativeBinding.TensorFlowLibrary)]
70+
private static extern unsafe LLBuffer* TF_NewBufferFromString(IntPtr proto, IntPtr proto_len);
71+
72+
// extern TF_Buffer * TF_NewBuffer ();
73+
[DllImport(NativeBinding.TensorFlowLibrary)]
74+
private static extern unsafe LLBuffer* TF_NewBuffer();
75+
76+
internal TFBuffer(IntPtr handle) : base(handle) { }
77+
78+
/// <summary>
79+
/// Initializes a new instance of the <see cref="T:TensorFlow.TFBuffer"/> class.
80+
/// </summary>
81+
public unsafe TFBuffer() : base((IntPtr)TF_NewBuffer())
82+
{
83+
}
84+
85+
/// <summary>
86+
/// Signature of the method that is invoked to release the data.
87+
/// </summary>
88+
/// <remarks>
89+
/// Methods of this signature are invoked with the data pointer and the
90+
/// lenght pointer when then TFBuffer no longer needs to hold on to the
91+
/// data. If you are using this on platforms with static compilation
92+
/// like iOS, you need to annotate your callback with the MonoPInvokeCallbackAttribute,
93+
/// like this:
94+
///
95+
/// <code>
96+
/// [TensorFlow.MonoPInvokeCallback (typeof (BufferReleaseFunc))]
97+
/// internal static void MyFreeFunc (IntPtr data, IntPtr length){..}
98+
/// </code>
99+
/// </remarks>
100+
public delegate void BufferReleaseFunc(IntPtr data, IntPtr lenght);
101+
102+
/// <summary>
103+
/// Initializes a new instance of the <see cref="T:TensorFlow.TFBuffer"/> by wrapping the unmanaged resource pointed by the buffer.
104+
/// </summary>
105+
/// <param name="buffer">Pointer to the data that will be wrapped.</param>
106+
/// <param name="size">The size of the buffer to wrap.</param>
107+
/// <param name="release">Optional, if not null, this method will be invoked to release the block.</param>
108+
/// <remarks>
109+
/// This constructor wraps the buffer as a the data to be held by the <see cref="T:TensorFlow.TFBuffer"/>,
110+
/// if the release parameter is null, then you must ensure that the data is not released before the TFBuffer
111+
/// is no longer in use. If the value is not null, the provided method will be invoked to release
112+
/// the data when the TFBuffer is disposed, or the contents of the buffer replaced.
113+
/// </remarks>
114+
public unsafe TFBuffer(IntPtr buffer, long size, BufferReleaseFunc release) : base((IntPtr)TF_NewBuffer())
115+
{
116+
LLBuffer* buf = (LLBuffer*)handle;
117+
buf->data = buffer;
118+
buf->length = (size_t)size;
119+
if (release == null)
120+
buf->data_deallocator = IntPtr.Zero;
121+
else
122+
buf->data_deallocator = Marshal.GetFunctionPointerForDelegate(release);
123+
}
124+
125+
[MonoPInvokeCallback(typeof(BufferReleaseFunc))]
126+
internal static void FreeBlock(IntPtr data, IntPtr length)
127+
{
128+
Marshal.FreeHGlobal(data);
129+
}
130+
131+
internal static IntPtr FreeBufferFunc;
132+
internal static BufferReleaseFunc FreeBlockDelegate;
133+
134+
static TFBuffer()
135+
{
136+
FreeBlockDelegate = FreeBlock;
137+
FreeBufferFunc = Marshal.GetFunctionPointerForDelegate<BufferReleaseFunc>(FreeBlockDelegate);
138+
}
139+
140+
/// <summary>
141+
/// Initializes a new instance of the <see cref="T:TensorFlow.TFBuffer"/> by making a copy of the provided byte array.
142+
/// </summary>
143+
/// <param name="buffer">Buffer of data that will be wrapped.</param>
144+
/// <remarks>
145+
/// This constructor makes a copy of the data into an unmanaged buffer,
146+
/// so the byte array is not pinned.
147+
/// </remarks>
148+
public TFBuffer(byte[] buffer) : this(buffer, 0, buffer.Length) { }
149+
150+
/// <summary>
151+
/// Initializes a new instance of the <see cref="T:TensorFlow.TFBuffer"/> by making a copy of the provided byte array.
152+
/// </summary>
153+
/// <param name="buffer">Buffer of data that will be wrapped.</param>
154+
/// <param name="start">Starting offset into the buffer to wrap.</param>
155+
/// <param name="count">Number of bytes from the buffer to keep.</param>
156+
/// <remarks>
157+
/// This constructor makes a copy of the data into an unmanaged buffer,
158+
/// so the byte array is not pinned.
159+
/// </remarks>
160+
public TFBuffer(byte[] buffer, int start, int count) : this()
161+
{
162+
if (start < 0 || start >= buffer.Length)
163+
throw new ArgumentException("start");
164+
if (count < 0 || count > buffer.Length - start)
165+
throw new ArgumentException("count");
166+
unsafe
167+
{
168+
LLBuffer* buf = LLBuffer;
169+
buf->data = Marshal.AllocHGlobal(count);
170+
Marshal.Copy(buffer, start, buf->data, count);
171+
buf->length = (size_t)count;
172+
buf->data_deallocator = FreeBufferFunc;
173+
}
174+
}
175+
176+
internal unsafe LLBuffer* LLBuffer => (LLBuffer*)handle;
177+
178+
// extern void TF_DeleteBuffer (TF_Buffer *);
179+
[DllImport(NativeBinding.TensorFlowLibrary)]
180+
private static extern unsafe void TF_DeleteBuffer(LLBuffer* buffer);
181+
182+
internal override void NativeDispose(IntPtr handle)
183+
{
184+
unsafe { TF_DeleteBuffer((LLBuffer*)handle); }
185+
}
186+
187+
// extern TF_Buffer TF_GetBuffer (TF_Buffer *buffer);
188+
[DllImport(NativeBinding.TensorFlowLibrary)]
189+
private static extern unsafe LLBuffer TF_GetBuffer(LLBuffer* buffer);
190+
191+
/// <summary>
192+
/// Returns a byte array representing the data wrapped by this buffer.
193+
/// </summary>
194+
/// <returns>The array.</returns>
195+
public byte[] ToArray()
196+
{
197+
if (handle == IntPtr.Zero)
198+
return null;
199+
200+
unsafe
201+
{
202+
var lb = (LLBuffer*)handle;
203+
204+
var result = new byte[(int)lb->length];
205+
Marshal.Copy(lb->data, result, 0, (int)lb->length);
206+
207+
return result;
208+
}
209+
}
210+
}
211+
}

0 commit comments

Comments
 (0)