Skip to content

Commit fd6b8d5

Browse files
authored
Merge pull request dotnet#3 from Oceania2018/tftransferlearning
Tftransferlearning
2 parents 5beab30 + f050f03 commit fd6b8d5

File tree

2 files changed

+4
-27
lines changed

2 files changed

+4
-27
lines changed

src/Microsoft.ML.Dnn/Microsoft.ML.Dnn.csproj

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
<ItemGroup>
1616
<PackageReference Include="System.IO.FileSystem.AccessControl" Version="$(SystemIOFileSystemAccessControl)" />
1717
<PackageReference Include="System.Security.Principal.Windows" Version="$(SystemSecurityPrincipalWindows)" />
18-
<PackageReference Include="TensorFlow.NET" Version="0.10.6" />
18+
<PackageReference Include="TensorFlow.NET" Version="0.10.7.2" />
1919
</ItemGroup>
2020

2121
<ItemGroup>

src/Microsoft.ML.Dnn/TensorflowUtils.cs

+3-26
Original file line numberDiff line numberDiff line change
@@ -291,34 +291,11 @@ internal static unsafe void FetchStringData<T>(Tensor tensor, Span<T> result)
291291
{
292292
if (tensor == null)
293293
throw Contracts.ExceptEmpty(nameof(tensor));
294-
//
295-
// TF_STRING tensors are encoded with a table of 8-byte offsets followed by TF_StringEncode-encoded bytes.
296-
// [offset1, offset2,...,offsetn, s1size, s1bytes, s2size, s2bytes,...,snsize,snbytes]
297-
//
298-
long size = 1;
299-
foreach (var s in tensor.TensorShape.Dimensions)
300-
size *= s;
301-
302-
var buffer = new byte[size][];
303-
var src = c_api.TF_TensorData(tensor);
304-
var srcLen = (IntPtr)(src.ToInt64() + (long)tensor.bytesize);
305-
src += (int)(size * 8);
306-
for (int i = 0; i < buffer.Length; i++)
307-
{
308-
using (var status = new Status())
309-
{
310-
IntPtr dst = IntPtr.Zero;
311-
ulong dstLen = 0;
312-
var read = c_api.TF_StringDecode(src, (ulong)(srcLen.ToInt64() - src.ToInt64()), dst, ref dstLen, status);
313-
status.Check();
314-
buffer[i] = new byte[(int)dstLen];
315-
Marshal.Copy(dst, buffer[i], 0, buffer[i].Length);
316-
src += (int)read;
317-
}
318-
}
294+
295+
var buffer = tensor.StringData();
319296

320297
for (int i = 0; i < buffer.Length; i++)
321-
result[i] = (T)(object)Encoding.UTF8.GetString(buffer[i]).AsMemory();
298+
result[i] = (T)(object)buffer[i].AsMemory();
322299
}
323300

324301
internal static bool IsTypeSupported(TF_DataType tfoutput)

0 commit comments

Comments
 (0)