Skip to content

Commit 2974c38

Browse files
committed
Merge remote-tracking branch 'upstream/master' into PCAEstimator
2 parents 83b1daf + 330aa41 commit 2974c38

File tree

15 files changed

+490
-85
lines changed

15 files changed

+490
-85
lines changed

build/Dependencies.props

+2
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,7 @@
1717
<MicrosoftCodeAnalysisCSharpVersion>2.9.0</MicrosoftCodeAnalysisCSharpVersion>
1818
<MicrosoftCSharpVersion>4.5.0</MicrosoftCSharpVersion>
1919
<SystemCompositionVersion>1.2.0</SystemCompositionVersion>
20+
<SystemIOFileSystemAccessControl>4.5.0</SystemIOFileSystemAccessControl>
21+
<SystemSecurityPrincipalWindows>4.5.0</SystemSecurityPrincipalWindows>
2022
</PropertyGroup>
2123
</Project>

pkg/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.nupkgproj

+4-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
</PropertyGroup>
77

88
<ItemGroup>
9-
<ProjectReference Include="..\Microsoft.ML.TensorFlow.Redist\Microsoft.ML.TensorFlow.Redist.nupkgproj" />
9+
<PackageReference Include="System.IO.FileSystem.AccessControl" Version="$(SystemIOFileSystemAccessControl)" />
10+
<PackageReference Include="System.Security.Principal.Windows" Version="$(SystemSecurityPrincipalWindows)" />
11+
<ProjectReference Include="../Microsoft.ML/Microsoft.ML.nupkgproj" />
12+
<ProjectReference Include="../Microsoft.ML.TensorFlow.Redist/Microsoft.ML.TensorFlow.Redist.nupkgproj" />
1013
</ItemGroup>
1114

1215
</Project>

src/Microsoft.ML.Core/Utilities/Stream.cs

+24
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,30 @@ public static void CloseEx(this TextWriter writer)
3030
writer.Close();
3131
}
3232

33+
/// <summary>
34+
/// Similar to Stream.CopyTo but takes a length rather than assuming copy to end. Returns amount copied.
35+
/// </summary>
36+
/// <param name="source">Source stream to copy from</param>
37+
/// <param name="destination">Destination stream to copy to</param>
38+
/// <param name="length">Number of bytes to copy</param>
39+
/// <param name="bufferSize">Size of buffer to use when copying, default is 81920 to match that of Stream</param>
40+
/// <returns>number of bytes copied</returns>
41+
public static long CopyRange(this Stream source, Stream destination, long length, int bufferSize = 81920)
42+
{
43+
// should use ArrayPool once we can take that dependency
44+
byte[] buffer = new byte[bufferSize];
45+
int read;
46+
long remaining = length;
47+
while (remaining != 0 &&
48+
(read = source.Read(buffer, 0, (int)Math.Min(buffer.Length, remaining))) != 0)
49+
{
50+
destination.Write(buffer, 0, read);
51+
remaining -= read;
52+
}
53+
54+
return length - remaining;
55+
}
56+
3357
public static void WriteBoolByte(this BinaryWriter writer, bool x)
3458
{
3559
Contracts.AssertValue(writer);

src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
<PropertyGroup>
44
<OutputType>Exe</OutputType>
5-
<TargetFramework>netcoreapp2.1</TargetFramework>
5+
<TargetFramework>netcoreapp2.0</TargetFramework>
66
<AssemblyName>DnnAnalyzer</AssemblyName>
77
<IncludeInPackage>Microsoft.ML.TensorFlow</IncludeInPackage>
88
</PropertyGroup>

src/Microsoft.ML.Legacy/CSharpApi.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -15787,9 +15787,9 @@ public sealed partial class TensorFlowScorer : Microsoft.ML.Runtime.EntryPoints.
1578715787

1578815788

1578915789
/// <summary>
15790-
/// This is the frozen protobuf model file. Please see https://www.tensorflow.org/mobile/prepare_models for more details.
15790+
/// TensorFlow model used by the transform. Please see https://www.tensorflow.org/mobile/prepare_models for more details.
1579115791
/// </summary>
15792-
public string ModelFile { get; set; }
15792+
public string Model { get; set; }
1579315793

1579415794
/// <summary>
1579515795
/// The names of the model inputs

src/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.csproj

+5
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
88
</PropertyGroup>
99

10+
<ItemGroup>
11+
<PackageReference Include="System.IO.FileSystem.AccessControl" Version="$(SystemIOFileSystemAccessControl)" />
12+
<PackageReference Include="System.Security.Principal.Windows" Version="$(SystemSecurityPrincipalWindows)" />
13+
</ItemGroup>
14+
1015
<ItemGroup>
1116
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
1217
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />

src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1182,7 +1182,7 @@ public IEnumerable<DeviceAttributes> ListDevices(TFStatus status = null)
11821182
/// here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md
11831183
/// </para>
11841184
/// </remarks>
1185-
public TFSession FromSavedModel(TFSessionOptions sessionOptions, TFBuffer runOptions, string exportDir, string[] tags, TFGraph graph, TFBuffer metaGraphDef, TFStatus status = null)
1185+
public static TFSession FromSavedModel(TFSessionOptions sessionOptions, TFBuffer runOptions, string exportDir, string[] tags, TFGraph graph, TFBuffer metaGraphDef, TFStatus status = null)
11861186
{
11871187
if (graph == null)
11881188
throw new ArgumentNullException(nameof(graph));

src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs

+152-4
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,17 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using Microsoft.ML.Runtime;
6+
using Microsoft.ML.Runtime.Data;
7+
using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints;
8+
using Microsoft.ML.Runtime.Internal.Utilities;
59
using System;
610
using System.Collections.Generic;
711
using System.IO;
812
using System.Linq;
913
using System.Runtime.InteropServices;
10-
using Microsoft.ML.Runtime;
11-
using Microsoft.ML.Runtime.Data;
12-
using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints;
13-
using Microsoft.ML.Runtime.Internal.Utilities;
14+
using System.Security.AccessControl;
15+
using System.Security.Principal;
1416

1517
namespace Microsoft.ML.Transforms.TensorFlow
1618
{
@@ -158,6 +160,152 @@ internal static TFSession LoadTFSession(IExceptionContext ectx, byte[] modelByte
158160
return new TFSession(graph);
159161
}
160162

163+
private static TFSession LoadTFSession(IHostEnvironment env, string exportDirSavedModel)
164+
{
165+
Contracts.Check(env != null, nameof(env));
166+
env.CheckValue(exportDirSavedModel, nameof(exportDirSavedModel));
167+
var sessionOptions = new TFSessionOptions();
168+
var tags = new string[] { "serve" };
169+
var graph = new TFGraph();
170+
var metaGraphDef = new TFBuffer();
171+
172+
return TFSession.FromSavedModel(sessionOptions, null, exportDirSavedModel, tags, graph, metaGraphDef);
173+
}
174+
175+
// A TensorFlow frozen model is a single file. An un-frozen (SavedModel) on the other hand has a well-defined folder structure.
176+
// Given a modelPath, this utility method determines if we should treat it as a SavedModel or not
177+
internal static bool IsSavedModel(IHostEnvironment env, string modelPath)
178+
{
179+
Contracts.Check(env != null, nameof(env));
180+
env.CheckNonWhiteSpace(modelPath, nameof(modelPath));
181+
FileAttributes attr = File.GetAttributes(modelPath);
182+
return attr.HasFlag(FileAttributes.Directory);
183+
}
184+
185+
// Currently used in TensorFlowTransform to protect temporary folders used when working with TensorFlow's SavedModel format.
186+
// Models are considered executable code, so we need to ACL tthe temp folders for high-rights process (so low-rights process can’t access it).
187+
/// <summary>
188+
/// Given a folder path, create it with proper ACL if it doesn't exist.
189+
/// Fails if the folder name is empty, or can't create the folder.
190+
/// </summary>
191+
internal static void CreateFolderWithAclIfNotExists(IHostEnvironment env, string folder)
192+
{
193+
Contracts.Check(env != null, nameof(env));
194+
env.CheckNonWhiteSpace(folder, nameof(folder));
195+
196+
//if directory exists, do nothing.
197+
if (Directory.Exists(folder))
198+
return;
199+
200+
WindowsIdentity currentIdentity = null;
201+
try
202+
{
203+
currentIdentity = WindowsIdentity.GetCurrent();
204+
}
205+
catch (PlatformNotSupportedException)
206+
{ }
207+
208+
if (currentIdentity != null && new WindowsPrincipal(currentIdentity).IsInRole(WindowsBuiltInRole.Administrator))
209+
{
210+
// Create high integrity dir and set no delete policy for all files under the directory.
211+
// In case of failure, throw exception.
212+
CreateTempDirectoryWithAcl(folder, currentIdentity.User.ToString());
213+
}
214+
else
215+
{
216+
try
217+
{
218+
Directory.CreateDirectory(folder);
219+
}
220+
catch (Exception exc)
221+
{
222+
throw Contracts.ExceptParam(nameof(folder), $"Failed to create folder for the provided path: {folder}. \nException: {exc.Message}");
223+
}
224+
}
225+
}
226+
227+
internal static void DeleteFolderWithRetries(IHostEnvironment env, string folder)
228+
{
229+
Contracts.Check(env != null, nameof(env));
230+
int currentRetry = 0;
231+
int maxRetryCount = 10;
232+
using (var ch = env.Start("Delete folder"))
233+
{
234+
for (; ; )
235+
{
236+
try
237+
{
238+
currentRetry++;
239+
Directory.Delete(folder, true);
240+
break;
241+
}
242+
catch (IOException e)
243+
{
244+
if (currentRetry > maxRetryCount)
245+
throw;
246+
ch.Info("Error deleting folder. {0}. Retry,", e.Message);
247+
}
248+
}
249+
}
250+
}
251+
252+
private static void CreateTempDirectoryWithAcl(string folder, string identity)
253+
{
254+
// Dacl Sddl string:
255+
// D: Dacl type
256+
// D; Deny access
257+
// OI; Object inherit ace
258+
// SD; Standard delete function
259+
// wIdentity.User Sid of the given user.
260+
// A; Allow access
261+
// OICI; Object inherit, container inherit
262+
// FA File access
263+
// BA Built-in administrators
264+
// S: Sacl type
265+
// ML;; Mandatory Label
266+
// NW;;; No write policy
267+
// HI High integrity processes only
268+
string sddl = "D:(D;OI;SD;;;" + identity + ")(A;OICI;FA;;;BA)S:(ML;OI;NW;;;HI)";
269+
270+
try
271+
{
272+
var dir = Directory.CreateDirectory(folder);
273+
DirectorySecurity dirSec = new DirectorySecurity();
274+
dirSec.SetSecurityDescriptorSddlForm(sddl);
275+
dirSec.SetAccessRuleProtection(true, false); // disable inheritance
276+
dir.SetAccessControl(dirSec);
277+
278+
// Cleaning out the directory, in case someone managed to sneak in between creation and setting ACL.
279+
DirectoryInfo dirInfo = new DirectoryInfo(folder);
280+
foreach (FileInfo file in dirInfo.GetFiles())
281+
{
282+
file.Delete();
283+
}
284+
foreach (DirectoryInfo subDirInfo in dirInfo.GetDirectories())
285+
{
286+
subDirInfo.Delete(true);
287+
}
288+
}
289+
catch (Exception exc)
290+
{
291+
throw Contracts.ExceptParam(nameof(folder), $"Failed to create folder for the provided path: {folder}. \nException: {exc.Message}");
292+
}
293+
}
294+
295+
internal static TFSession GetSession(IHostEnvironment env, string modelPath)
296+
{
297+
Contracts.Check(env != null, nameof(env));
298+
if (IsSavedModel(env, modelPath))
299+
{
300+
env.CheckUserArg(Directory.Exists(modelPath), nameof(modelPath));
301+
return LoadTFSession(env, modelPath);
302+
}
303+
304+
env.CheckUserArg(File.Exists(modelPath), nameof(modelPath));
305+
var bytes = File.ReadAllBytes(modelPath);
306+
return LoadTFSession(env, bytes, modelPath);
307+
}
308+
161309
internal static unsafe void FetchData<T>(IntPtr data, T[] result)
162310
{
163311
var size = result.Length;

0 commit comments

Comments
 (0)