diff --git a/build/Dependencies.props b/build/Dependencies.props index c1334615db..845d722820 100644 --- a/build/Dependencies.props +++ b/build/Dependencies.props @@ -17,5 +17,7 @@ 2.9.0 4.5.0 1.2.0 + 4.5.0 + 4.5.0 diff --git a/pkg/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.nupkgproj b/pkg/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.nupkgproj index fb07384f90..c762af0c75 100644 --- a/pkg/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.nupkgproj +++ b/pkg/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.nupkgproj @@ -6,7 +6,10 @@ - + + + + diff --git a/src/Microsoft.ML.Core/Utilities/Stream.cs b/src/Microsoft.ML.Core/Utilities/Stream.cs index 7fbf0148b9..80e18b6bf3 100644 --- a/src/Microsoft.ML.Core/Utilities/Stream.cs +++ b/src/Microsoft.ML.Core/Utilities/Stream.cs @@ -30,6 +30,30 @@ public static void CloseEx(this TextWriter writer) writer.Close(); } + /// + /// Similar to Stream.CopyTo but takes a length rather than assuming copy to end. Returns amount copied. + /// + /// Source stream to copy from + /// Destination stream to copy to + /// Number of bytes to copy + /// Size of buffer to use when copying, default is 81920 to match that of Stream + /// number of bytes copied + public static long CopyRange(this Stream source, Stream destination, long length, int bufferSize = 81920) + { + // should use ArrayPool once we can take that dependency + byte[] buffer = new byte[bufferSize]; + int read; + long remaining = length; + while (remaining != 0 && + (read = source.Read(buffer, 0, (int)Math.Min(buffer.Length, remaining))) != 0) + { + destination.Write(buffer, 0, read); + remaining -= read; + } + + return length - remaining; + } + public static void WriteBoolByte(this BinaryWriter writer, bool x) { Contracts.AssertValue(writer); diff --git a/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj index 7c77ff2ffa..e50bed112c 100644 --- 
a/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj +++ b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj @@ -2,7 +2,7 @@ Exe - netcoreapp2.1 + netcoreapp2.0 DnnAnalyzer Microsoft.ML.TensorFlow diff --git a/src/Microsoft.ML.Legacy/CSharpApi.cs b/src/Microsoft.ML.Legacy/CSharpApi.cs index ba6f0f866e..634e071f4e 100644 --- a/src/Microsoft.ML.Legacy/CSharpApi.cs +++ b/src/Microsoft.ML.Legacy/CSharpApi.cs @@ -15787,9 +15787,9 @@ public sealed partial class TensorFlowScorer : Microsoft.ML.Runtime.EntryPoints. /// - /// This is the frozen protobuf model file. Please see https://www.tensorflow.org/mobile/prepare_models for more details. + /// TensorFlow model used by the transform. Please see https://www.tensorflow.org/mobile/prepare_models for more details. /// - public string ModelFile { get; set; } + public string Model { get; set; } /// /// The names of the model inputs diff --git a/src/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.csproj b/src/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.csproj index a0ee42b93c..9c756460bc 100644 --- a/src/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.csproj +++ b/src/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.csproj @@ -7,6 +7,11 @@ true + + + + + diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs index e63e4f56c2..aba15ea04c 100644 --- a/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs @@ -1182,7 +1182,7 @@ public IEnumerable ListDevices(TFStatus status = null) /// here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md /// /// - public TFSession FromSavedModel(TFSessionOptions sessionOptions, TFBuffer runOptions, string exportDir, string[] tags, TFGraph graph, TFBuffer metaGraphDef, TFStatus status = null) + public static TFSession FromSavedModel(TFSessionOptions 
sessionOptions, TFBuffer runOptions, string exportDir, string[] tags, TFGraph graph, TFBuffer metaGraphDef, TFStatus status = null) { if (graph == null) throw new ArgumentNullException(nameof(graph)); diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs index 54030aec91..281582e00a 100644 --- a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs +++ b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs @@ -2,15 +2,17 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints; +using Microsoft.ML.Runtime.Internal.Utilities; using System; using System.Collections.Generic; using System.IO; using System.Linq; using System.Runtime.InteropServices; -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints; -using Microsoft.ML.Runtime.Internal.Utilities; +using System.Security.AccessControl; +using System.Security.Principal; namespace Microsoft.ML.Transforms.TensorFlow { @@ -158,6 +160,152 @@ internal static TFSession LoadTFSession(IExceptionContext ectx, byte[] modelByte return new TFSession(graph); } + private static TFSession LoadTFSession(IHostEnvironment env, string exportDirSavedModel) + { + Contracts.Check(env != null, nameof(env)); + env.CheckValue(exportDirSavedModel, nameof(exportDirSavedModel)); + var sessionOptions = new TFSessionOptions(); + var tags = new string[] { "serve" }; + var graph = new TFGraph(); + var metaGraphDef = new TFBuffer(); + + return TFSession.FromSavedModel(sessionOptions, null, exportDirSavedModel, tags, graph, metaGraphDef); + } + + // A TensorFlow frozen model is a single file. An un-frozen (SavedModel) on the other hand has a well-defined folder structure. 
+ // Given a modelPath, this utility method determines if we should treat it as a SavedModel or not + internal static bool IsSavedModel(IHostEnvironment env, string modelPath) + { + Contracts.Check(env != null, nameof(env)); + env.CheckNonWhiteSpace(modelPath, nameof(modelPath)); + FileAttributes attr = File.GetAttributes(modelPath); + return attr.HasFlag(FileAttributes.Directory); + } + + // Currently used in TensorFlowTransform to protect temporary folders used when working with TensorFlow's SavedModel format. + // Models are considered executable code, so we need to ACL the temp folders for high-rights process (so low-rights process can't access it). + /// + /// Given a folder path, create it with proper ACL if it doesn't exist. + /// Fails if the folder name is empty, or can't create the folder. + /// + internal static void CreateFolderWithAclIfNotExists(IHostEnvironment env, string folder) + { + Contracts.Check(env != null, nameof(env)); + env.CheckNonWhiteSpace(folder, nameof(folder)); + + //if directory exists, do nothing. + if (Directory.Exists(folder)) + return; + + WindowsIdentity currentIdentity = null; + try + { + currentIdentity = WindowsIdentity.GetCurrent(); + } + catch (PlatformNotSupportedException) + { } + + if (currentIdentity != null && new WindowsPrincipal(currentIdentity).IsInRole(WindowsBuiltInRole.Administrator)) + { + // Create high integrity dir and set no delete policy for all files under the directory. + // In case of failure, throw exception. + CreateTempDirectoryWithAcl(folder, currentIdentity.User.ToString()); + } + else + { + try + { + Directory.CreateDirectory(folder); + } + catch (Exception exc) + { + throw Contracts.ExceptParam(nameof(folder), $"Failed to create folder for the provided path: {folder}. 
\nException: {exc.Message}"); + } + } + } + + internal static void DeleteFolderWithRetries(IHostEnvironment env, string folder) + { + Contracts.Check(env != null, nameof(env)); + int currentRetry = 0; + int maxRetryCount = 10; + using (var ch = env.Start("Delete folder")) + { + for (; ; ) + { + try + { + currentRetry++; + Directory.Delete(folder, true); + break; + } + catch (IOException e) + { + if (currentRetry > maxRetryCount) + throw; + ch.Info("Error deleting folder. {0}. Retry,", e.Message); + } + } + } + } + + private static void CreateTempDirectoryWithAcl(string folder, string identity) + { + // Dacl Sddl string: + // D: Dacl type + // D; Deny access + // OI; Object inherit ace + // SD; Standard delete function + // wIdentity.User Sid of the given user. + // A; Allow access + // OICI; Object inherit, container inherit + // FA File access + // BA Built-in administrators + // S: Sacl type + // ML;; Mandatory Label + // NW;;; No write policy + // HI High integrity processes only + string sddl = "D:(D;OI;SD;;;" + identity + ")(A;OICI;FA;;;BA)S:(ML;OI;NW;;;HI)"; + + try + { + var dir = Directory.CreateDirectory(folder); + DirectorySecurity dirSec = new DirectorySecurity(); + dirSec.SetSecurityDescriptorSddlForm(sddl); + dirSec.SetAccessRuleProtection(true, false); // disable inheritance + dir.SetAccessControl(dirSec); + + // Cleaning out the directory, in case someone managed to sneak in between creation and setting ACL. + DirectoryInfo dirInfo = new DirectoryInfo(folder); + foreach (FileInfo file in dirInfo.GetFiles()) + { + file.Delete(); + } + foreach (DirectoryInfo subDirInfo in dirInfo.GetDirectories()) + { + subDirInfo.Delete(true); + } + } + catch (Exception exc) + { + throw Contracts.ExceptParam(nameof(folder), $"Failed to create folder for the provided path: {folder}. 
\nException: {exc.Message}"); + } + } + + internal static TFSession GetSession(IHostEnvironment env, string modelPath) + { + Contracts.Check(env != null, nameof(env)); + if (IsSavedModel(env, modelPath)) + { + env.CheckUserArg(Directory.Exists(modelPath), nameof(modelPath)); + return LoadTFSession(env, modelPath); + } + + env.CheckUserArg(File.Exists(modelPath), nameof(modelPath)); + var bytes = File.ReadAllBytes(modelPath); + return LoadTFSession(env, bytes, modelPath); + } + internal static unsafe void FetchData(IntPtr data, T[] result) { var size = result.Length; diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index acf3c4ec23..78a49fac1c 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.IO; +using System.IO.Compression; using System.Linq; using Microsoft.ML.Core.Data; using Microsoft.ML.Data.StaticPipe.Runtime; @@ -38,9 +39,8 @@ public sealed class TensorFlowTransform : ITransformer, ICanSaveModel { public sealed class Arguments : TransformInputBase { - - [Argument(ArgumentType.Required, HelpText = "This is the frozen protobuf model file. Please see https://www.tensorflow.org/mobile/prepare_models for more details.", ShortName = "model", SortOrder = 0)] - public string ModelFile; + [Argument(ArgumentType.Required, HelpText = "TensorFlow model used by the transform. 
Please see https://www.tensorflow.org/mobile/prepare_models for more details.", SortOrder = 0)] + public string Model; [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "The names of the model inputs", ShortName = "inputs", SortOrder = 1)] public string[] InputColumns; @@ -50,6 +50,8 @@ public sealed class Arguments : TransformInputBase } private readonly IHost _host; + private readonly string _savedModelPath; + private readonly bool _isTemporarySavedModel; private const string RegistrationName = "TensorFlowTransform"; internal readonly TFSession Session; @@ -73,7 +75,7 @@ private static VersionInfo GetVersionInfo() return new VersionInfo( modelSignature: "TENSFLOW", //verWrittenCur: 0x00010001, // Initial - verWrittenCur: 0x00010002, // Upgraded when change for multiple outputs was implemented. + verWrittenCur: 0x00010002, // Added Support for Multiple Outputs and SavedModel. verReadableCur: 0x00010002, verWeCanReadBack: 0x00010001, loaderSignature: LoaderSignature); @@ -84,25 +86,12 @@ private static VersionInfo GetVersionInfo() /// /// Host Environment. /// Input . This is the output from previous transform or loader. - /// This is the frozen TensorFlow model file. https://www.tensorflow.org/mobile/prepare_models - /// Name of the output column. Keep it same as in the TensorFlow model. - /// Name of the input column(s). Keep it same as in the TensorFlow model. - public static IDataTransform Create(IHostEnvironment env, IDataView input, string modelFile, string name, params string[] source) - { - return new TensorFlowTransform(env, modelFile, source, new[] { name }).MakeDataTransform(input); - } - - /// - /// Convenience constructor for public facing API. - /// - /// Host Environment. - /// Input . This is the output from previous transform or loader. - /// This is the frozen tensorflow model file. https://www.tensorflow.org/mobile/prepare_models + /// Path to the TensorFlow model. /// Name of the output column(s). 
Keep it same as in the Tensorflow model. /// Name of the input column(s). Keep it same as in the Tensorflow model. - public static IDataTransform Create(IHostEnvironment env, IDataView input, string modelFile, string[] names, string[] source) + public static IDataTransform Create(IHostEnvironment env, IDataView input, string model, string[] names, string[] source) { - return new TensorFlowTransform(env, modelFile, source, names).MakeDataTransform(input); + return new TensorFlowTransform(env, TensorFlowUtils.GetSession(env, model), source, names, TensorFlowUtils.IsSavedModel(env, model) ? model : null, false).MakeDataTransform(input); } // Factory method for SignatureLoadModel. @@ -111,7 +100,9 @@ private static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext Contracts.CheckValue(env, nameof(env)); env.CheckValue(ctx, nameof(ctx)); ctx.CheckAtModel(GetVersionInfo()); + // *** Binary format *** + // byte: indicator for frozen models // stream: tensorFlow model. // int: number of input columns // for each input column @@ -119,27 +110,48 @@ private static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext // int: number of output columns // for each output column // int: id of output column name - byte[] modelBytes = null; - if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray())) - throw env.ExceptDecode(); - var session = TensorFlowUtils.LoadTFSession(env, modelBytes); - var numInputs = ctx.Reader.ReadInt32(); - env.CheckDecode(numInputs > 0); - string[] inputs = new string[numInputs]; - for (int j = 0; j < inputs.Length; j++) - inputs[j] = ctx.LoadNonEmptyString(); - - bool isMultiOutput = ctx.Header.ModelVerReadable >= 0x00010002; - var numOutputs = 1; - if (isMultiOutput) - numOutputs = ctx.Reader.ReadInt32(); + GetModelInfo(env, ctx, out string[] inputs, out string[] outputs, out bool isFrozen); + if (isFrozen) + { + byte[] modelBytes = null; + if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = 
r.ReadByteArray())) + throw env.ExceptDecode(); + return new TensorFlowTransform(env, TensorFlowUtils.LoadTFSession(env, modelBytes), inputs, outputs, null, false); + } - env.CheckDecode(numOutputs > 0); - var outputs = new string[numOutputs]; - for (int j = 0; j < outputs.Length; j++) - outputs[j] = ctx.LoadNonEmptyString(); + var tempDirPath = Path.GetFullPath(Path.Combine(Path.GetTempPath(), RegistrationName + "_" + Guid.NewGuid())); + TensorFlowUtils.CreateFolderWithAclIfNotExists(env, tempDirPath); + try + { + var load = ctx.TryLoadBinaryStream("TFSavedModel", br => + { + int count = br.ReadInt32(); + for (int n = 0; n < count; n++) + { + string relativeFile = br.ReadString(); + long fileLength = br.ReadInt64(); + + string fullFilePath = Path.Combine(tempDirPath, relativeFile); + string fullFileDir = Path.GetDirectoryName(fullFilePath); + if (fullFileDir != tempDirPath) + { + TensorFlowUtils.CreateFolderWithAclIfNotExists(env, fullFileDir); + } + using (var fs = new FileStream(fullFilePath, FileMode.Create, FileAccess.Write)) + { + long actualRead = br.BaseStream.CopyRange(fs, fileLength); + env.Assert(actualRead == fileLength); + } + } + }); - return new TensorFlowTransform(env, session, inputs, outputs); + return new TensorFlowTransform(env, TensorFlowUtils.GetSession(env, tempDirPath), inputs, outputs, tempDirPath, true); + } + catch (Exception) + { + TensorFlowUtils.DeleteFolderWithRetries(env, tempDirPath); + throw; + } } // Factory method for SignatureDataTransform. 
@@ -150,7 +162,8 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, IData env.CheckValue(input, nameof(input)); env.CheckValue(args.InputColumns, nameof(args.InputColumns)); env.CheckValue(args.OutputColumns, nameof(args.OutputColumns)); - return new TensorFlowTransform(env, args.ModelFile, args.InputColumns, args.OutputColumns).MakeDataTransform(input); + + return new TensorFlowTransform(env, TensorFlowUtils.GetSession(env, args.Model), args.InputColumns, args.OutputColumns, TensorFlowUtils.IsSavedModel(env, args.Model) ? args.Model : null, false).MakeDataTransform(input); } // Factory method for SignatureLoadDataTransform. @@ -161,27 +174,42 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - private static TFSession CheckFileAndRead(IHostEnvironment env, string modelFile) + private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out string[] inputs, out string[] outputs, out bool isFrozen) { - env.CheckNonWhiteSpace(modelFile, nameof(modelFile)); - env.CheckUserArg(File.Exists(modelFile), nameof(modelFile)); - var bytes = File.ReadAllBytes(modelFile); - return TensorFlowUtils.LoadTFSession(env, bytes, modelFile); - } + isFrozen = true; + bool isNonFrozenModelSupported = ctx.Header.ModelVerReadable >= 0x00010002; + if (isNonFrozenModelSupported) + isFrozen = ctx.Reader.ReadBoolByte(); - public TensorFlowTransform(IHostEnvironment env, string modelFile, string[] inputs, string[] outputs) : - this(env, CheckFileAndRead(env, modelFile), inputs, outputs) - { + var numInputs = ctx.Reader.ReadInt32(); + env.CheckDecode(numInputs > 0); + inputs = new string[numInputs]; + for (int j = 0; j < inputs.Length; j++) + inputs[j] = ctx.LoadNonEmptyString(); + + bool isMultiOutput = ctx.Header.ModelVerReadable >= 0x00010002; + var numOutputs = 1; + if 
(isMultiOutput) + numOutputs = ctx.Reader.ReadInt32(); + + env.CheckDecode(numOutputs > 0); + outputs = new string[numOutputs]; + for (int j = 0; j < outputs.Length; j++) + outputs[j] = ctx.LoadNonEmptyString(); } - private TensorFlowTransform(IHostEnvironment env, TFSession session, string[] inputs, string[] outputs) + internal TensorFlowTransform(IHostEnvironment env, TFSession session, string[] inputs, string[] outputs, string savedModelPath, bool isTemporarySavedModel) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(RegistrationName)); _host.CheckValue(session, nameof(session)); _host.CheckNonEmpty(inputs, nameof(inputs)); _host.CheckNonEmpty(outputs, nameof(outputs)); + Session = session; + _savedModelPath = savedModelPath; + _isTemporarySavedModel = isTemporarySavedModel; + foreach (var input in inputs) { _host.CheckNonWhiteSpace(input, nameof(inputs)); @@ -261,7 +289,9 @@ public void Save(ModelSaveContext ctx) _host.AssertValue(ctx); ctx.CheckAtModel(); ctx.SetVersionInfo(GetVersionInfo()); + // *** Binary format *** + // byte: indicator for frozen models // stream: tensorFlow model. 
// int: number of input columns // for each input column @@ -269,14 +299,39 @@ public void Save(ModelSaveContext ctx) // int: number of output columns // for each output column // int: id of output column name - - var buffer = new TFBuffer(); - Session.Graph.ToGraphDef(buffer); - - ctx.SaveBinaryStream("TFModel", w => + var isFrozen = string.IsNullOrEmpty(_savedModelPath); + ctx.Writer.WriteBoolByte(isFrozen); + if (isFrozen) { - w.WriteByteArray(buffer.ToArray()); - }); + var buffer = new TFBuffer(); + Session.Graph.ToGraphDef(buffer); + ctx.SaveBinaryStream("TFModel", w => + { + w.WriteByteArray(buffer.ToArray()); + }); + } + else + { + ctx.SaveBinaryStream("TFSavedModel", w => + { + string[] modelFilePaths = Directory.GetFiles(_savedModelPath, "*", SearchOption.AllDirectories); + w.Write(modelFilePaths.Length); + + foreach (var fullPath in modelFilePaths) + { + var relativePath = fullPath.Substring(_savedModelPath.Length + 1); + w.Write(relativePath); + + using (var fs = new FileStream(fullPath, FileMode.Open)) + { + long fileLength = fs.Length; + w.Write(fileLength); + long actualWritten = fs.CopyRange(w.BaseStream, fileLength); + _host.Assert(actualWritten == fileLength); + } + } + }); + } _host.AssertNonEmpty(Inputs); ctx.Writer.Write(Inputs.Length); foreach (var colName in Inputs) @@ -288,6 +343,33 @@ public void Save(ModelSaveContext ctx) ctx.SaveNonEmptyString(colName); } + ~TensorFlowTransform() + { + Dispose(false); + } + + private void Dispose(bool disposing) + { + // Ensure that the Session is not null and it's handle is not Zero, as it may have already been disposed/finalized. + // Technically we shouldn't be calling this if disposing == false, since we're running in finalizer + // and the GC doesn't guarantee ordering of finalization of managed objects, but we have to make sure + // that the Session is closed before deleting our temporary directory. 
+ try + { + if (Session?.Handle != IntPtr.Zero) + { + Session.CloseSession(); + Session.Dispose(); + } + } + finally + { + if (!string.IsNullOrEmpty(_savedModelPath) && _isTemporarySavedModel) + { + TensorFlowUtils.DeleteFolderWithRetries(_host, _savedModelPath); + } + } + } public bool IsRowToRowMapper => true; public IRowToRowMapper GetRowToRowMapper(ISchema inputSchema) @@ -322,6 +404,9 @@ public Mapper(IHostEnvironment env, TensorFlowTransform parent, ISchema inputSch throw _host.Except($"Column {_parent.Inputs[i]} doesn't exist"); var type = inputSchema.GetColumnType(_inputColIndices[i]); + if (type.IsVector && type.VectorSize == 0) + throw _host.Except($"Variable length input columns not supported"); + _isInputVector[i] = type.IsVector; var expectedType = TensorFlowUtils.Tf2MlNetType(_parent.TFInputTypes[i]); if (type.ItemType != expectedType) @@ -564,8 +649,8 @@ public static CommonOutputs.TransformOutput TensorFlowScorer(IHostEnvironment en public sealed class TensorFlowEstimator : TrivialEstimator { - public TensorFlowEstimator(IHostEnvironment env, string modelFile, string[] inputs, string[] outputs) - : this(env, new TensorFlowTransform(env, modelFile, inputs, outputs)) + public TensorFlowEstimator(IHostEnvironment env, string model, string[] inputs, string[] outputs) + : this(env, new TensorFlowTransform(env, TensorFlowUtils.GetSession(env, model), inputs, outputs, TensorFlowUtils.IsSavedModel(env, model) ? 
model : null, false)) { } @@ -584,7 +669,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var input = Transformer.Inputs[i]; if (!inputSchema.TryFindColumn(input, out var col)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input); - if (!(col.Kind == SchemaShape.Column.VectorKind.VariableVector || col.Kind == SchemaShape.Column.VectorKind.Vector)) + if (!(col.Kind == SchemaShape.Column.VectorKind.Vector)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, nameof(VectorType), col.GetTypeString()); var expectedType = TensorFlowUtils.Tf2MlNetType(Transformer.TFInputTypes[i]); if (col.ItemType != expectedType) @@ -602,7 +687,6 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) public static class TensorFlowStaticExtensions { - private sealed class OutColumn : Vector { public PipelineColumn Input { get; } diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json index 1ff41a090b..03f8cb170f 100644 --- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json +++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json @@ -21729,12 +21729,9 @@ "ShortName": "TFTransform", "Inputs": [ { - "Name": "ModelFile", + "Name": "Model", "Type": "String", - "Desc": "This is the frozen protobuf model file. Please see https://www.tensorflow.org/mobile/prepare_models for more details.", - "Aliases": [ - "model" - ], + "Desc": "TensorFlow model used by the transform. 
Please see https://www.tensorflow.org/mobile/prepare_models for more details.", "Required": true, "SortOrder": 0.0, "IsNullable": false diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs index b3d7068041..2e5292d13f 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs @@ -977,9 +977,9 @@ public void TestTensorFlowEntryPoint() var tfTransformInput = new Legacy.Transforms.TensorFlowScorer { Data = importOutput.Data, + Model = "mnist_model/frozen_saved_model.pb", InputColumns = new[] { "Placeholder" }, OutputColumns = new[] { "Softmax" }, - ModelFile = "mnist_model/frozen_saved_model.pb" }; var tfTransformOutput = experiment.Add(tfTransformInput); diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index d674d99c61..b1e86c70a2 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -3733,7 +3733,7 @@ public void EntryPointTensorFlowTransform() new[] { @"'InputColumns': [ 'Placeholder' ], - 'ModelFile': 'mnist_model/frozen_saved_model.pb', + 'Model': 'mnist_model/frozen_saved_model.pb', 'OutputColumns': [ 'Softmax' ]" }); } diff --git a/test/Microsoft.ML.Tests/Scenarios/PipelineApi/TrainSaveModelAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/PipelineApi/TrainSaveModelAndPredict.cs index 2c25745100..cfeed15bd5 100644 --- a/test/Microsoft.ML.Tests/Scenarios/PipelineApi/TrainSaveModelAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/PipelineApi/TrainSaveModelAndPredict.cs @@ -36,7 +36,6 @@ public async void TrainSaveModelAndPredict() var loadedModel = await Legacy.PredictionModel.ReadAsync(modelName); var singlePrediction = loadedModel.Predict(new SentimentData() { SentimentText = "Not big fan of this." 
}); Assert.True(singlePrediction.Sentiment); - } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs index 6a03f114b8..4685999813 100644 --- a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs @@ -44,7 +44,7 @@ public void TensorFlowTransformCifarLearningPipelineTest() pipeline.Add(new TensorFlowScorer() { - ModelFile = model_location, + Model = model_location, InputColumns = new[] { "Input" }, OutputColumns = new[] { "Output" } }); @@ -121,7 +121,7 @@ public void TensorFlowTransformInceptionPipelineTest() pipeline.Add(new TensorFlowScorer() { - ModelFile = model_location, + Model = model_location, InputColumns = new[] { inputTensorName }, OutputColumns = new[] { outputTensorName } }); diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index 6ee86513bc..dfca4ebf6a 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -44,7 +44,7 @@ public void TensorFlowTransformMatrixMultiplicationTest() b = new[] { 3.0f, 3.0f, 3.0f, 3.0f } } })); - var trans = TensorFlowTransform.Create(env, loader, model_location, "c", "a", "b"); + var trans = TensorFlowTransform.Create(env, loader, model_location, new[] { "c" }, new[] { "a", "b" }); using (var cursor = trans.GetRowCursor(a => true)) { @@ -163,7 +163,7 @@ public void TensorFlowTransformInceptionTest() } }, cropped); - var tf = TensorFlowTransform.Create(env, pixels, model_location, "softmax2_pre_activation", "input"); + var tf = TensorFlowTransform.Create(env, pixels, model_location, new[] { "softmax2_pre_activation" }, new[] { "input" }); tf.Schema.TryGetColumnIndex("input", out int input); tf.Schema.TryGetColumnIndex("softmax2_pre_activation", out int b); @@ 
-353,6 +353,88 @@ public void TensorFlowTransformMNISTConvTest() } } + [Fact] + public void TensorFlowTransformMNISTConvSavedModelTest() + { + var model_location = "mnist_model"; + using (var env = new ConsoleEnvironment(seed: 1, conc: 1)) + { + var dataPath = GetDataPath("Train-Tiny-28x28.txt"); + var testDataPath = GetDataPath("MNIST.Test.tiny.txt"); + + // Pipeline + var loader = TextLoader.ReadFile(env, + new TextLoader.Arguments() + { + Separator = "tab", + HasHeader = true, + Column = new[] + { + new TextLoader.Column("Label", DataKind.Num,0), + new TextLoader.Column("Placeholder", DataKind.Num,new []{new TextLoader.Range(1, 784) }) + + } + }, new MultiFileSource(dataPath)); + + IDataView trans = CopyColumnsTransform.Create(env, new CopyColumnsTransform.Arguments() + { + Column = new[] { new CopyColumnsTransform.Column() + { Name = "reshape_input", Source = "Placeholder" } + } + }, loader); + trans = TensorFlowTransform.Create(env, trans, model_location, new[] { "Softmax", "dense/Relu" }, new[] { "Placeholder", "reshape_input" }); + trans = new ConcatTransform(env, "Features", "Softmax", "dense/Relu").Transform(trans); + + var trainer = new LightGbmMulticlassTrainer(env, "Label", "Features"); + + var cached = new CacheDataView(env, trans, prefetch: null); + var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features"); + var pred = trainer.Train(trainRoles); + + // Get scorer and evaluate the predictions from test data + IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath); + var metrics = Evaluate(env, testDataScorer); + + Assert.Equal(0.99, metrics.AccuracyMicro, 2); + Assert.Equal(1.0, metrics.AccuracyMacro, 2); + + // Create prediction engine and test predictions + var model = env.CreatePredictionEngine(testDataScorer); + + var sample1 = new MNISTData() + { + Placeholder = new float[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 18, 18, 18, 126, 136, 175, 26, + 166, 255, 247, 127, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 30, 36, 94, 154, 170, 253, 253, 253, 253, 253, + 225, 172, 253, 242, 195, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 49, 238, 253, 253, 253, 253, 253, 253, 253, + 253, 251, 93, 82, 82, 56, 39, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 219, 253, 253, 253, 253, 253, 198, + 182, 247, 241, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 80, 156, 107, 253, 253, 205, 11, 0, + 43, 154, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 1, 154, 253, 90, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 139, 253, 190, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 190, 253, 70, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35, 241, 225, 160, 108, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 81, 240, 253, 253, 119, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 45, 186, 253, 253, 150, 27, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 93, 252, 253, 187, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 249, 253, 249, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 46, 130, 183, 253, 253, 207, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 39, 148, 229, 253, 253, 253, 250, 182, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24, 114, 221, 253, 253, 253, 253, 201, 78, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 23, 66, 213, 253, 253, 253, 253, 198, 81, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 
171, 219, 253, 253, 253, 253, 195, 80, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 55, 172, 226, 253, 253, 253, 253, 244, 133, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 136, 253, 253, 253, 212, 135, 132, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 } + }; + + var prediction = model.Predict(sample1); + + float max = -1; + int maxIndex = -1; + for (int i = 0; i < prediction.PredictedLabels.Length; i++) + { + if (prediction.PredictedLabels[i] > max) + { + max = prediction.PredictedLabels[i]; + maxIndex = i; + } + } + + Assert.Equal(5, maxIndex); + } + } + [Fact] public void TensorFlowTransformMNISTConvPipelineTest() { @@ -364,7 +446,7 @@ public void TensorFlowTransformMNISTConvPipelineTest() pipeline.Add(new Legacy.Transforms.ColumnCopier() { Column = new[] { new CopyColumnsTransformColumn() { Name = "reshape_input", Source = "Placeholder" } } }); pipeline.Add(new TensorFlowScorer() { - ModelFile = model_location, + Model = model_location, OutputColumns = new[] { "Softmax", "dense/Relu" }, InputColumns = new[] { "Placeholder", "reshape_input" } }); @@ -444,7 +526,61 @@ public void TensorFlowTransformCifar() }, cropped); - IDataView trans = TensorFlowTransform.Create(env, pixels, model_location, "Output", "Input"); + IDataView trans = TensorFlowTransform.Create(env, pixels, model_location, new[] { "Output" }, new[] { "Input" }); + + trans.Schema.TryGetColumnIndex("Output", out int output); + using (var cursor = trans.GetRowCursor(col => col == output)) + { + var buffer = default(VBuffer); + var getter = cursor.GetGetter>(output); + var numRows = 0; + while (cursor.MoveNext()) + { + getter(ref buffer); + Assert.Equal(10, buffer.Length); + numRows += 1; + } + Assert.Equal(3, numRows); + } + } + 
} + + /* Test added by this hunk: loads image paths from images/images.tsv, resizes each image to 32x32 with IsoCrop, extracts interleaved pixels without alpha into the Input column, and applies the TensorFlow model at cifar_saved_model. The remainder of the method is in context lines that continue past this hunk. */ [Fact] + public void TensorFlowTransformCifarSavedModel() + { + var model_location = "cifar_saved_model"; + + using (var env = new ConsoleEnvironment()) + { + var imageHeight = 32; + var imageWidth = 32; + var dataFile = GetDataPath("images/images.tsv"); + var imageFolder = Path.GetDirectoryName(dataFile); + var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); + var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments() + { + Column = new ImageLoaderTransform.Column[1] + { + new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" } + }, + ImageFolder = imageFolder + }, data); + var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments() + { + Column = new ImageResizerTransform.Column[1]{ + new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop} + } + }, images); + + var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments() + { + Column = new ImagePixelExtractorTransform.Column[1]{ + new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "Input", UseAlpha=false, InterleaveArgb=true} + } + }, cropped); + + + /* Output and Input are passed as one-element arrays of column names. */ IDataView trans = TensorFlowTransform.Create(env, pixels, model_location, new[] { "Output" }, new[] { "Input" }); trans.Schema.TryGetColumnIndex("Output", out int output); using (var cursor = trans.GetRowCursor(col => col == output)) @@ -501,7 +637,7 @@ public void TensorFlowTransformCifarInvalidShape() var thrown = false; try { /* Hunk: output and input column names are now passed to Create as arrays instead of single strings. */ - IDataView trans = TensorFlowTransform.Create(env, pixels, model_location, "Output", "Input"); + IDataView trans = TensorFlowTransform.Create(env, pixels, model_location, new[] { "Output" }, new[] { "Input" }); } catch {