Skip to content

Commit 893a385

Browse files
authored
Enable TensorFlowTransform to work with pre-trained models that are not frozen (dotnet#853)
* building transform from ground up * dummy transform works after fixing the getters * SavedModel format works for Train, but fails for Save&Predict * remove dummy transform * remove dummy unit test * Works with non-frozen models * building transform from ground up * dummy transform works after fixing the getters * SavedModel format works for Train, but fails for Save&Predict * remove dummy transform * remove dummy unit test * fix compilation issues; verify existing tests work fine * works locally; need to refactor code * refactored code; keeping only 1 version of the convenience API * added class for directory structure * using latest nuget package (0.0.3) for Microsoft.ML.TensorFlow.TestModels * delete temporary files used when loading/saving models * delete local models; the updated nuget version (0.0.3) for Microsoft.ML.TensorFlow.TestModels contains these models * modified logic for load/restore of models * modified logic for load&restore of unfrozen models * model version update to support non-frozen models * based on the code review comments, we now infer if the provided model is frozen or not * simplify the logic in Save() related to loading of SavedModel. * trying Eric's suggestion * revert back to previous changes * attempt to use stream copy approach instead of in-memory * deleting some commented out code * Ensure we only copy the file segment & cleanup path logic * added finalizer that closes the session (if it isn't closed) and deletes the temporary directory * cleanup + misc review comments * trying to create temp dir with proper ACLs for high priviledge users * create temp dir with proper ACLs for high-privilege processes * fix build after merge with latest master * taking care of review comments related to model versioning of TFTransform * remove IDisposable from the TensorFlowTransform; renaming some methods * refactor code so we have only 1 constructor for the TensorFlowTransform (as suggested in review comment) * fix issues with nuget packaging; refactored the code + added comments * add checks in code to make sure that the input is not a variable length vector * fix typo in name of package * (1) added SavedModel test for MNIST model (2) added try/finally for deleting temp folder (3) deleted test using Legacy Learning API * remove and sort usings in file TrainSaveModelAndPredict.cs * using spaces in nupkgproj * error checking for passed in IHostEnvironment * fix TargetFramework version (netcore 2.0) of DnnAnalyzer to match that of Microsoft.ML.TensorFlow
1 parent a02807c commit 893a385

File tree

15 files changed

+483
-85
lines changed

15 files changed

+483
-85
lines changed

build/Dependencies.props

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,7 @@
1717
<MicrosoftCodeAnalysisCSharpVersion>2.9.0</MicrosoftCodeAnalysisCSharpVersion>
1818
<MicrosoftCSharpVersion>4.5.0</MicrosoftCSharpVersion>
1919
<SystemCompositionVersion>1.2.0</SystemCompositionVersion>
20+
<SystemIOFileSystemAccessControl>4.5.0</SystemIOFileSystemAccessControl>
21+
<SystemSecurityPrincipalWindows>4.5.0</SystemSecurityPrincipalWindows>
2022
</PropertyGroup>
2123
</Project>

pkg/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.nupkgproj

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
</PropertyGroup>
77

88
<ItemGroup>
9-
<ProjectReference Include="..\Microsoft.ML.TensorFlow.Redist\Microsoft.ML.TensorFlow.Redist.nupkgproj" />
9+
<PackageReference Include="System.IO.FileSystem.AccessControl" Version="$(SystemIOFileSystemAccessControl)" />
10+
<PackageReference Include="System.Security.Principal.Windows" Version="$(SystemSecurityPrincipalWindows)" />
11+
<ProjectReference Include="../Microsoft.ML/Microsoft.ML.nupkgproj" />
12+
<ProjectReference Include="../Microsoft.ML.TensorFlow.Redist/Microsoft.ML.TensorFlow.Redist.nupkgproj" />
1013
</ItemGroup>
1114

1215
</Project>

src/Microsoft.ML.Core/Utilities/Stream.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,30 @@ public static void CloseEx(this TextWriter writer)
3030
writer.Close();
3131
}
3232

33+
/// <summary>
34+
/// Similar to Stream.CopyTo but takes a length rather than assuming copy to end. Returns amount copied.
35+
/// </summary>
36+
/// <param name="source">Source stream to copy from</param>
37+
/// <param name="destination">Destination stream to copy to</param>
38+
/// <param name="length">Number of bytes to copy</param>
39+
/// <param name="bufferSize">Size of buffer to use when copying, default is 81920 to match that of Stream</param>
40+
/// <returns>number of bytes copied</returns>
41+
public static long CopyRange(this Stream source, Stream destination, long length, int bufferSize = 81920)
42+
{
43+
// should use ArrayPool once we can take that dependency
44+
byte[] buffer = new byte[bufferSize];
45+
int read;
46+
long remaining = length;
47+
while (remaining != 0 &&
48+
(read = source.Read(buffer, 0, (int)Math.Min(buffer.Length, remaining))) != 0)
49+
{
50+
destination.Write(buffer, 0, read);
51+
remaining -= read;
52+
}
53+
54+
return length - remaining;
55+
}
56+
3357
public static void WriteBoolByte(this BinaryWriter writer, bool x)
3458
{
3559
Contracts.AssertValue(writer);

src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
<PropertyGroup>
44
<OutputType>Exe</OutputType>
5-
<TargetFramework>netcoreapp2.1</TargetFramework>
5+
<TargetFramework>netcoreapp2.0</TargetFramework>
66
<AssemblyName>DnnAnalyzer</AssemblyName>
77
<IncludeInPackage>Microsoft.ML.TensorFlow</IncludeInPackage>
88
</PropertyGroup>

src/Microsoft.ML.Legacy/CSharpApi.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15787,9 +15787,9 @@ public sealed partial class TensorFlowScorer : Microsoft.ML.Runtime.EntryPoints.
1578715787

1578815788

1578915789
/// <summary>
15790-
/// This is the frozen protobuf model file. Please see https://www.tensorflow.org/mobile/prepare_models for more details.
15790+
/// TensorFlow model used by the transform. Please see https://www.tensorflow.org/mobile/prepare_models for more details.
1579115791
/// </summary>
15792-
public string ModelFile { get; set; }
15792+
public string Model { get; set; }
1579315793

1579415794
/// <summary>
1579515795
/// The names of the model inputs

src/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.csproj

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
88
</PropertyGroup>
99

10+
<ItemGroup>
11+
<PackageReference Include="System.IO.FileSystem.AccessControl" Version="$(SystemIOFileSystemAccessControl)" />
12+
<PackageReference Include="System.Security.Principal.Windows" Version="$(SystemSecurityPrincipalWindows)" />
13+
</ItemGroup>
14+
1015
<ItemGroup>
1116
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
1217
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />

src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1182,7 +1182,7 @@ public IEnumerable<DeviceAttributes> ListDevices(TFStatus status = null)
11821182
/// here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md
11831183
/// </para>
11841184
/// </remarks>
1185-
public TFSession FromSavedModel(TFSessionOptions sessionOptions, TFBuffer runOptions, string exportDir, string[] tags, TFGraph graph, TFBuffer metaGraphDef, TFStatus status = null)
1185+
public static TFSession FromSavedModel(TFSessionOptions sessionOptions, TFBuffer runOptions, string exportDir, string[] tags, TFGraph graph, TFBuffer metaGraphDef, TFStatus status = null)
11861186
{
11871187
if (graph == null)
11881188
throw new ArgumentNullException(nameof(graph));

src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs

Lines changed: 152 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,17 @@
22
// The .NET Foundation licenses this file to you under the MIT license.
33
// See the LICENSE file in the project root for more information.
44

5+
using Microsoft.ML.Runtime;
6+
using Microsoft.ML.Runtime.Data;
7+
using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints;
8+
using Microsoft.ML.Runtime.Internal.Utilities;
59
using System;
610
using System.Collections.Generic;
711
using System.IO;
812
using System.Linq;
913
using System.Runtime.InteropServices;
10-
using Microsoft.ML.Runtime;
11-
using Microsoft.ML.Runtime.Data;
12-
using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints;
13-
using Microsoft.ML.Runtime.Internal.Utilities;
14+
using System.Security.AccessControl;
15+
using System.Security.Principal;
1416

1517
namespace Microsoft.ML.Transforms.TensorFlow
1618
{
@@ -158,6 +160,152 @@ internal static TFSession LoadTFSession(IExceptionContext ectx, byte[] modelByte
158160
return new TFSession(graph);
159161
}
160162

163+
private static TFSession LoadTFSession(IHostEnvironment env, string exportDirSavedModel)
164+
{
165+
Contracts.Check(env != null, nameof(env));
166+
env.CheckValue(exportDirSavedModel, nameof(exportDirSavedModel));
167+
var sessionOptions = new TFSessionOptions();
168+
var tags = new string[] { "serve" };
169+
var graph = new TFGraph();
170+
var metaGraphDef = new TFBuffer();
171+
172+
return TFSession.FromSavedModel(sessionOptions, null, exportDirSavedModel, tags, graph, metaGraphDef);
173+
}
174+
175+
// A TensorFlow frozen model is a single file. An un-frozen (SavedModel) on the other hand has a well-defined folder structure.
176+
// Given a modelPath, this utility method determines if we should treat it as a SavedModel or not
177+
internal static bool IsSavedModel(IHostEnvironment env, string modelPath)
178+
{
179+
Contracts.Check(env != null, nameof(env));
180+
env.CheckNonWhiteSpace(modelPath, nameof(modelPath));
181+
FileAttributes attr = File.GetAttributes(modelPath);
182+
return attr.HasFlag(FileAttributes.Directory);
183+
}
184+
185+
// Currently used in TensorFlowTransform to protect temporary folders used when working with TensorFlow's SavedModel format.
186+
// Models are considered executable code, so we need to ACL tthe temp folders for high-rights process (so low-rights process can’t access it).
187+
/// <summary>
188+
/// Given a folder path, create it with proper ACL if it doesn't exist.
189+
/// Fails if the folder name is empty, or can't create the folder.
190+
/// </summary>
191+
internal static void CreateFolderWithAclIfNotExists(IHostEnvironment env, string folder)
192+
{
193+
Contracts.Check(env != null, nameof(env));
194+
env.CheckNonWhiteSpace(folder, nameof(folder));
195+
196+
//if directory exists, do nothing.
197+
if (Directory.Exists(folder))
198+
return;
199+
200+
WindowsIdentity currentIdentity = null;
201+
try
202+
{
203+
currentIdentity = WindowsIdentity.GetCurrent();
204+
}
205+
catch (PlatformNotSupportedException)
206+
{ }
207+
208+
if (currentIdentity != null && new WindowsPrincipal(currentIdentity).IsInRole(WindowsBuiltInRole.Administrator))
209+
{
210+
// Create high integrity dir and set no delete policy for all files under the directory.
211+
// In case of failure, throw exception.
212+
CreateTempDirectoryWithAcl(folder, currentIdentity.User.ToString());
213+
}
214+
else
215+
{
216+
try
217+
{
218+
Directory.CreateDirectory(folder);
219+
}
220+
catch (Exception exc)
221+
{
222+
throw Contracts.ExceptParam(nameof(folder), $"Failed to create folder for the provided path: {folder}. \nException: {exc.Message}");
223+
}
224+
}
225+
}
226+
227+
internal static void DeleteFolderWithRetries(IHostEnvironment env, string folder)
228+
{
229+
Contracts.Check(env != null, nameof(env));
230+
int currentRetry = 0;
231+
int maxRetryCount = 10;
232+
using (var ch = env.Start("Delete folder"))
233+
{
234+
for (; ; )
235+
{
236+
try
237+
{
238+
currentRetry++;
239+
Directory.Delete(folder, true);
240+
break;
241+
}
242+
catch (IOException e)
243+
{
244+
if (currentRetry > maxRetryCount)
245+
throw;
246+
ch.Info("Error deleting folder. {0}. Retry,", e.Message);
247+
}
248+
}
249+
}
250+
}
251+
252+
private static void CreateTempDirectoryWithAcl(string folder, string identity)
253+
{
254+
// Dacl Sddl string:
255+
// D: Dacl type
256+
// D; Deny access
257+
// OI; Object inherit ace
258+
// SD; Standard delete function
259+
// wIdentity.User Sid of the given user.
260+
// A; Allow access
261+
// OICI; Object inherit, container inherit
262+
// FA File access
263+
// BA Built-in administrators
264+
// S: Sacl type
265+
// ML;; Mandatory Label
266+
// NW;;; No write policy
267+
// HI High integrity processes only
268+
string sddl = "D:(D;OI;SD;;;" + identity + ")(A;OICI;FA;;;BA)S:(ML;OI;NW;;;HI)";
269+
270+
try
271+
{
272+
var dir = Directory.CreateDirectory(folder);
273+
DirectorySecurity dirSec = new DirectorySecurity();
274+
dirSec.SetSecurityDescriptorSddlForm(sddl);
275+
dirSec.SetAccessRuleProtection(true, false); // disable inheritance
276+
dir.SetAccessControl(dirSec);
277+
278+
// Cleaning out the directory, in case someone managed to sneak in between creation and setting ACL.
279+
DirectoryInfo dirInfo = new DirectoryInfo(folder);
280+
foreach (FileInfo file in dirInfo.GetFiles())
281+
{
282+
file.Delete();
283+
}
284+
foreach (DirectoryInfo subDirInfo in dirInfo.GetDirectories())
285+
{
286+
subDirInfo.Delete(true);
287+
}
288+
}
289+
catch (Exception exc)
290+
{
291+
throw Contracts.ExceptParam(nameof(folder), $"Failed to create folder for the provided path: {folder}. \nException: {exc.Message}");
292+
}
293+
}
294+
295+
internal static TFSession GetSession(IHostEnvironment env, string modelPath)
296+
{
297+
Contracts.Check(env != null, nameof(env));
298+
if (IsSavedModel(env, modelPath))
299+
{
300+
env.CheckUserArg(Directory.Exists(modelPath), nameof(modelPath));
301+
return LoadTFSession(env, modelPath);
302+
}
303+
304+
env.CheckUserArg(File.Exists(modelPath), nameof(modelPath));
305+
var bytes = File.ReadAllBytes(modelPath);
306+
return LoadTFSession(env, bytes, modelPath);
307+
}
308+
161309
internal static unsafe void FetchData<T>(IntPtr data, T[] result)
162310
{
163311
var size = result.Length;

0 commit comments

Comments
 (0)