|
2 | 2 | // The .NET Foundation licenses this file to you under the MIT license.
|
3 | 3 | // See the LICENSE file in the project root for more information.
|
4 | 4 |
|
| 5 | +using Microsoft.ML.Runtime; |
| 6 | +using Microsoft.ML.Runtime.Data; |
| 7 | +using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints; |
| 8 | +using Microsoft.ML.Runtime.Internal.Utilities; |
5 | 9 | using System;
|
6 | 10 | using System.Collections.Generic;
|
7 | 11 | using System.IO;
|
8 | 12 | using System.Linq;
|
9 | 13 | using System.Runtime.InteropServices;
|
10 |
| -using Microsoft.ML.Runtime; |
11 |
| -using Microsoft.ML.Runtime.Data; |
12 |
| -using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints; |
13 |
| -using Microsoft.ML.Runtime.Internal.Utilities; |
| 14 | +using System.Security.AccessControl; |
| 15 | +using System.Security.Principal; |
14 | 16 |
|
15 | 17 | namespace Microsoft.ML.Transforms.TensorFlow
|
16 | 18 | {
|
@@ -158,6 +160,152 @@ internal static TFSession LoadTFSession(IExceptionContext ectx, byte[] modelByte
|
158 | 160 | return new TFSession(graph);
|
159 | 161 | }
|
160 | 162 |
|
| 163 | + private static TFSession LoadTFSession(IHostEnvironment env, string exportDirSavedModel) |
| 164 | + { |
| 165 | + Contracts.Check(env != null, nameof(env)); |
| 166 | + env.CheckValue(exportDirSavedModel, nameof(exportDirSavedModel)); |
| 167 | + var sessionOptions = new TFSessionOptions(); |
| 168 | + var tags = new string[] { "serve" }; |
| 169 | + var graph = new TFGraph(); |
| 170 | + var metaGraphDef = new TFBuffer(); |
| 171 | + |
| 172 | + return TFSession.FromSavedModel(sessionOptions, null, exportDirSavedModel, tags, graph, metaGraphDef); |
| 173 | + } |
| 174 | + |
| 175 | + // A TensorFlow frozen model is a single file. An un-frozen (SavedModel) on the other hand has a well-defined folder structure. |
| 176 | + // Given a modelPath, this utility method determines if we should treat it as a SavedModel or not |
| 177 | + internal static bool IsSavedModel(IHostEnvironment env, string modelPath) |
| 178 | + { |
| 179 | + Contracts.Check(env != null, nameof(env)); |
| 180 | + env.CheckNonWhiteSpace(modelPath, nameof(modelPath)); |
| 181 | + FileAttributes attr = File.GetAttributes(modelPath); |
| 182 | + return attr.HasFlag(FileAttributes.Directory); |
| 183 | + } |
| 184 | + |
| 185 | + // Currently used in TensorFlowTransform to protect temporary folders used when working with TensorFlow's SavedModel format. |
| 186 | + // Models are considered executable code, so we need to ACL tthe temp folders for high-rights process (so low-rights process can’t access it). |
| 187 | + /// <summary> |
| 188 | + /// Given a folder path, create it with proper ACL if it doesn't exist. |
| 189 | + /// Fails if the folder name is empty, or can't create the folder. |
| 190 | + /// </summary> |
| 191 | + internal static void CreateFolderWithAclIfNotExists(IHostEnvironment env, string folder) |
| 192 | + { |
| 193 | + Contracts.Check(env != null, nameof(env)); |
| 194 | + env.CheckNonWhiteSpace(folder, nameof(folder)); |
| 195 | + |
| 196 | + //if directory exists, do nothing. |
| 197 | + if (Directory.Exists(folder)) |
| 198 | + return; |
| 199 | + |
| 200 | + WindowsIdentity currentIdentity = null; |
| 201 | + try |
| 202 | + { |
| 203 | + currentIdentity = WindowsIdentity.GetCurrent(); |
| 204 | + } |
| 205 | + catch (PlatformNotSupportedException) |
| 206 | + { } |
| 207 | + |
| 208 | + if (currentIdentity != null && new WindowsPrincipal(currentIdentity).IsInRole(WindowsBuiltInRole.Administrator)) |
| 209 | + { |
| 210 | + // Create high integrity dir and set no delete policy for all files under the directory. |
| 211 | + // In case of failure, throw exception. |
| 212 | + CreateTempDirectoryWithAcl(folder, currentIdentity.User.ToString()); |
| 213 | + } |
| 214 | + else |
| 215 | + { |
| 216 | + try |
| 217 | + { |
| 218 | + Directory.CreateDirectory(folder); |
| 219 | + } |
| 220 | + catch (Exception exc) |
| 221 | + { |
| 222 | + throw Contracts.ExceptParam(nameof(folder), $"Failed to create folder for the provided path: {folder}. \nException: {exc.Message}"); |
| 223 | + } |
| 224 | + } |
| 225 | + } |
| 226 | + |
| 227 | + internal static void DeleteFolderWithRetries(IHostEnvironment env, string folder) |
| 228 | + { |
| 229 | + Contracts.Check(env != null, nameof(env)); |
| 230 | + int currentRetry = 0; |
| 231 | + int maxRetryCount = 10; |
| 232 | + using (var ch = env.Start("Delete folder")) |
| 233 | + { |
| 234 | + for (; ; ) |
| 235 | + { |
| 236 | + try |
| 237 | + { |
| 238 | + currentRetry++; |
| 239 | + Directory.Delete(folder, true); |
| 240 | + break; |
| 241 | + } |
| 242 | + catch (IOException e) |
| 243 | + { |
| 244 | + if (currentRetry > maxRetryCount) |
| 245 | + throw; |
| 246 | + ch.Info("Error deleting folder. {0}. Retry,", e.Message); |
| 247 | + } |
| 248 | + } |
| 249 | + } |
| 250 | + } |
| 251 | + |
| 252 | + private static void CreateTempDirectoryWithAcl(string folder, string identity) |
| 253 | + { |
| 254 | + // Dacl Sddl string: |
| 255 | + // D: Dacl type |
| 256 | + // D; Deny access |
| 257 | + // OI; Object inherit ace |
| 258 | + // SD; Standard delete function |
| 259 | + // wIdentity.User Sid of the given user. |
| 260 | + // A; Allow access |
| 261 | + // OICI; Object inherit, container inherit |
| 262 | + // FA File access |
| 263 | + // BA Built-in administrators |
| 264 | + // S: Sacl type |
| 265 | + // ML;; Mandatory Label |
| 266 | + // NW;;; No write policy |
| 267 | + // HI High integrity processes only |
| 268 | + string sddl = "D:(D;OI;SD;;;" + identity + ")(A;OICI;FA;;;BA)S:(ML;OI;NW;;;HI)"; |
| 269 | + |
| 270 | + try |
| 271 | + { |
| 272 | + var dir = Directory.CreateDirectory(folder); |
| 273 | + DirectorySecurity dirSec = new DirectorySecurity(); |
| 274 | + dirSec.SetSecurityDescriptorSddlForm(sddl); |
| 275 | + dirSec.SetAccessRuleProtection(true, false); // disable inheritance |
| 276 | + dir.SetAccessControl(dirSec); |
| 277 | + |
| 278 | + // Cleaning out the directory, in case someone managed to sneak in between creation and setting ACL. |
| 279 | + DirectoryInfo dirInfo = new DirectoryInfo(folder); |
| 280 | + foreach (FileInfo file in dirInfo.GetFiles()) |
| 281 | + { |
| 282 | + file.Delete(); |
| 283 | + } |
| 284 | + foreach (DirectoryInfo subDirInfo in dirInfo.GetDirectories()) |
| 285 | + { |
| 286 | + subDirInfo.Delete(true); |
| 287 | + } |
| 288 | + } |
| 289 | + catch (Exception exc) |
| 290 | + { |
| 291 | + throw Contracts.ExceptParam(nameof(folder), $"Failed to create folder for the provided path: {folder}. \nException: {exc.Message}"); |
| 292 | + } |
| 293 | + } |
| 294 | + |
| 295 | + internal static TFSession GetSession(IHostEnvironment env, string modelPath) |
| 296 | + { |
| 297 | + Contracts.Check(env != null, nameof(env)); |
| 298 | + if (IsSavedModel(env, modelPath)) |
| 299 | + { |
| 300 | + env.CheckUserArg(Directory.Exists(modelPath), nameof(modelPath)); |
| 301 | + return LoadTFSession(env, modelPath); |
| 302 | + } |
| 303 | + |
| 304 | + env.CheckUserArg(File.Exists(modelPath), nameof(modelPath)); |
| 305 | + var bytes = File.ReadAllBytes(modelPath); |
| 306 | + return LoadTFSession(env, bytes, modelPath); |
| 307 | + } |
| 308 | + |
161 | 309 | internal static unsafe void FetchData<T>(IntPtr data, T[] result)
|
162 | 310 | {
|
163 | 311 | var size = result.Length;
|
|
0 commit comments