diff --git a/build/Dependencies.props b/build/Dependencies.props
index c1334615db..845d722820 100644
--- a/build/Dependencies.props
+++ b/build/Dependencies.props
@@ -17,5 +17,7 @@
2.9.0
4.5.0
1.2.0
+ 4.5.0
+ 4.5.0
diff --git a/pkg/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.nupkgproj b/pkg/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.nupkgproj
index fb07384f90..c762af0c75 100644
--- a/pkg/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.nupkgproj
+++ b/pkg/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.nupkgproj
@@ -6,7 +6,10 @@
-
+
+
+
+
diff --git a/src/Microsoft.ML.Core/Utilities/Stream.cs b/src/Microsoft.ML.Core/Utilities/Stream.cs
index 7fbf0148b9..80e18b6bf3 100644
--- a/src/Microsoft.ML.Core/Utilities/Stream.cs
+++ b/src/Microsoft.ML.Core/Utilities/Stream.cs
@@ -30,6 +30,30 @@ public static void CloseEx(this TextWriter writer)
writer.Close();
}
+ ///
+ /// Similar to Stream.CopyTo but takes a length rather than assuming copy to end. Returns amount copied.
+ ///
+ /// Source stream to copy from
+ /// Destination stream to copy to
+ /// Number of bytes to copy
+ /// Size of buffer to use when copying, default is 81920 to match that of Stream
+ /// number of bytes copied
+ public static long CopyRange(this Stream source, Stream destination, long length, int bufferSize = 81920)
+ {
+ // should use ArrayPool once we can take that dependency
+ byte[] buffer = new byte[bufferSize];
+ int read;
+ long remaining = length;
+ while (remaining != 0 &&
+ (read = source.Read(buffer, 0, (int)Math.Min(buffer.Length, remaining))) != 0)
+ {
+ destination.Write(buffer, 0, read);
+ remaining -= read;
+ }
+
+ return length - remaining;
+ }
+
public static void WriteBoolByte(this BinaryWriter writer, bool x)
{
Contracts.AssertValue(writer);
diff --git a/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj
index 7c77ff2ffa..e50bed112c 100644
--- a/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj
+++ b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj
@@ -2,7 +2,7 @@
Exe
- netcoreapp2.1
+ netcoreapp2.0
DnnAnalyzer
Microsoft.ML.TensorFlow
diff --git a/src/Microsoft.ML.Legacy/CSharpApi.cs b/src/Microsoft.ML.Legacy/CSharpApi.cs
index ba6f0f866e..634e071f4e 100644
--- a/src/Microsoft.ML.Legacy/CSharpApi.cs
+++ b/src/Microsoft.ML.Legacy/CSharpApi.cs
@@ -15787,9 +15787,9 @@ public sealed partial class TensorFlowScorer : Microsoft.ML.Runtime.EntryPoints.
///
- /// This is the frozen protobuf model file. Please see https://www.tensorflow.org/mobile/prepare_models for more details.
+ /// TensorFlow model used by the transform. Please see https://www.tensorflow.org/mobile/prepare_models for more details.
///
- public string ModelFile { get; set; }
+ public string Model { get; set; }
///
/// The names of the model inputs
diff --git a/src/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.csproj b/src/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.csproj
index a0ee42b93c..9c756460bc 100644
--- a/src/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.csproj
+++ b/src/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.csproj
@@ -7,6 +7,11 @@
true
+
+
+
+
+
diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs
index e63e4f56c2..aba15ea04c 100644
--- a/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs
@@ -1182,7 +1182,7 @@ public IEnumerable ListDevices(TFStatus status = null)
/// here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md
///
///
- public TFSession FromSavedModel(TFSessionOptions sessionOptions, TFBuffer runOptions, string exportDir, string[] tags, TFGraph graph, TFBuffer metaGraphDef, TFStatus status = null)
+ public static TFSession FromSavedModel(TFSessionOptions sessionOptions, TFBuffer runOptions, string exportDir, string[] tags, TFGraph graph, TFBuffer metaGraphDef, TFStatus status = null)
{
if (graph == null)
throw new ArgumentNullException(nameof(graph));
diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs
index 54030aec91..281582e00a 100644
--- a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs
@@ -2,15 +2,17 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints;
+using Microsoft.ML.Runtime.Internal.Utilities;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
-using Microsoft.ML.Runtime;
-using Microsoft.ML.Runtime.Data;
-using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints;
-using Microsoft.ML.Runtime.Internal.Utilities;
+using System.Security.AccessControl;
+using System.Security.Principal;
namespace Microsoft.ML.Transforms.TensorFlow
{
@@ -158,6 +160,152 @@ internal static TFSession LoadTFSession(IExceptionContext ectx, byte[] modelByte
return new TFSession(graph);
}
+ private static TFSession LoadTFSession(IHostEnvironment env, string exportDirSavedModel)
+ {
+ Contracts.Check(env != null, nameof(env));
+ env.CheckValue(exportDirSavedModel, nameof(exportDirSavedModel));
+ var sessionOptions = new TFSessionOptions();
+ var tags = new string[] { "serve" };
+ var graph = new TFGraph();
+ var metaGraphDef = new TFBuffer();
+
+ return TFSession.FromSavedModel(sessionOptions, null, exportDirSavedModel, tags, graph, metaGraphDef);
+ }
+
+ // A TensorFlow frozen model is a single file. An un-frozen (SavedModel) on the other hand has a well-defined folder structure.
+ // Given a modelPath, this utility method determines if we should treat it as a SavedModel or not
+ internal static bool IsSavedModel(IHostEnvironment env, string modelPath)
+ {
+ Contracts.Check(env != null, nameof(env));
+ env.CheckNonWhiteSpace(modelPath, nameof(modelPath));
+ FileAttributes attr = File.GetAttributes(modelPath);
+ return attr.HasFlag(FileAttributes.Directory);
+ }
+
+ // Currently used in TensorFlowTransform to protect temporary folders used when working with TensorFlow's SavedModel format.
+ // Models are considered executable code, so we need to ACL the temp folders for high-rights processes (so low-rights processes can't access them).
+ ///
+ /// Given a folder path, create it with proper ACL if it doesn't exist.
+ /// Fails if the folder name is empty, or can't create the folder.
+ ///
+ internal static void CreateFolderWithAclIfNotExists(IHostEnvironment env, string folder)
+ {
+ Contracts.Check(env != null, nameof(env));
+ env.CheckNonWhiteSpace(folder, nameof(folder));
+
+ //if directory exists, do nothing.
+ if (Directory.Exists(folder))
+ return;
+
+ WindowsIdentity currentIdentity = null;
+ try
+ {
+ currentIdentity = WindowsIdentity.GetCurrent();
+ }
+ catch (PlatformNotSupportedException)
+ { }
+
+ if (currentIdentity != null && new WindowsPrincipal(currentIdentity).IsInRole(WindowsBuiltInRole.Administrator))
+ {
+ // Create high integrity dir and set no delete policy for all files under the directory.
+ // In case of failure, throw exception.
+ CreateTempDirectoryWithAcl(folder, currentIdentity.User.ToString());
+ }
+ else
+ {
+ try
+ {
+ Directory.CreateDirectory(folder);
+ }
+ catch (Exception exc)
+ {
+ throw Contracts.ExceptParam(nameof(folder), $"Failed to create folder for the provided path: {folder}. \nException: {exc.Message}");
+ }
+ }
+ }
+
+ internal static void DeleteFolderWithRetries(IHostEnvironment env, string folder)
+ {
+ Contracts.Check(env != null, nameof(env));
+ int currentRetry = 0;
+ int maxRetryCount = 10;
+ using (var ch = env.Start("Delete folder"))
+ {
+ for (; ; )
+ {
+ try
+ {
+ currentRetry++;
+ Directory.Delete(folder, true);
+ break;
+ }
+ catch (IOException e)
+ {
+ if (currentRetry > maxRetryCount)
+ throw;
+ ch.Info("Error deleting folder. {0}. Retry,", e.Message);
+ }
+ }
+ }
+ }
+
+ private static void CreateTempDirectoryWithAcl(string folder, string identity)
+ {
+ // Dacl Sddl string:
+ // D: Dacl type
+ // D; Deny access
+ // OI; Object inherit ace
+ // SD; Standard delete function
+ // wIdentity.User Sid of the given user.
+ // A; Allow access
+ // OICI; Object inherit, container inherit
+ // FA File access
+ // BA Built-in administrators
+ // S: Sacl type
+ // ML;; Mandatory Label
+ // NW;;; No write policy
+ // HI High integrity processes only
+ string sddl = "D:(D;OI;SD;;;" + identity + ")(A;OICI;FA;;;BA)S:(ML;OI;NW;;;HI)";
+
+ try
+ {
+ var dir = Directory.CreateDirectory(folder);
+ DirectorySecurity dirSec = new DirectorySecurity();
+ dirSec.SetSecurityDescriptorSddlForm(sddl);
+ dirSec.SetAccessRuleProtection(true, false); // disable inheritance
+ dir.SetAccessControl(dirSec);
+
+ // Cleaning out the directory, in case someone managed to sneak in between creation and setting ACL.
+ DirectoryInfo dirInfo = new DirectoryInfo(folder);
+ foreach (FileInfo file in dirInfo.GetFiles())
+ {
+ file.Delete();
+ }
+ foreach (DirectoryInfo subDirInfo in dirInfo.GetDirectories())
+ {
+ subDirInfo.Delete(true);
+ }
+ }
+ catch (Exception exc)
+ {
+ throw Contracts.ExceptParam(nameof(folder), $"Failed to create folder for the provided path: {folder}. \nException: {exc.Message}");
+ }
+ }
+
+ internal static TFSession GetSession(IHostEnvironment env, string modelPath)
+ {
+ Contracts.Check(env != null, nameof(env));
+ if (IsSavedModel(env, modelPath))
+ {
+ env.CheckUserArg(Directory.Exists(modelPath), nameof(modelPath));
+ return LoadTFSession(env, modelPath);
+ }
+
+ env.CheckUserArg(File.Exists(modelPath), nameof(modelPath));
+ var bytes = File.ReadAllBytes(modelPath);
+ return LoadTFSession(env, bytes, modelPath);
+ }
+
internal static unsafe void FetchData(IntPtr data, T[] result)
{
var size = result.Length;
diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
index acf3c4ec23..78a49fac1c 100644
--- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
@@ -5,6 +5,7 @@
using System;
using System.Collections.Generic;
using System.IO;
+using System.IO.Compression;
using System.Linq;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Data.StaticPipe.Runtime;
@@ -38,9 +39,8 @@ public sealed class TensorFlowTransform : ITransformer, ICanSaveModel
{
public sealed class Arguments : TransformInputBase
{
-
- [Argument(ArgumentType.Required, HelpText = "This is the frozen protobuf model file. Please see https://www.tensorflow.org/mobile/prepare_models for more details.", ShortName = "model", SortOrder = 0)]
- public string ModelFile;
+ [Argument(ArgumentType.Required, HelpText = "TensorFlow model used by the transform. Please see https://www.tensorflow.org/mobile/prepare_models for more details.", SortOrder = 0)]
+ public string Model;
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "The names of the model inputs", ShortName = "inputs", SortOrder = 1)]
public string[] InputColumns;
@@ -50,6 +50,8 @@ public sealed class Arguments : TransformInputBase
}
private readonly IHost _host;
+ private readonly string _savedModelPath;
+ private readonly bool _isTemporarySavedModel;
private const string RegistrationName = "TensorFlowTransform";
internal readonly TFSession Session;
@@ -73,7 +75,7 @@ private static VersionInfo GetVersionInfo()
return new VersionInfo(
modelSignature: "TENSFLOW",
//verWrittenCur: 0x00010001, // Initial
- verWrittenCur: 0x00010002, // Upgraded when change for multiple outputs was implemented.
+ verWrittenCur: 0x00010002, // Added Support for Multiple Outputs and SavedModel.
verReadableCur: 0x00010002,
verWeCanReadBack: 0x00010001,
loaderSignature: LoaderSignature);
@@ -84,25 +86,12 @@ private static VersionInfo GetVersionInfo()
///
/// Host Environment.
/// Input . This is the output from previous transform or loader.
- /// This is the frozen TensorFlow model file. https://www.tensorflow.org/mobile/prepare_models
- /// Name of the output column. Keep it same as in the TensorFlow model.
- /// Name of the input column(s). Keep it same as in the TensorFlow model.
- public static IDataTransform Create(IHostEnvironment env, IDataView input, string modelFile, string name, params string[] source)
- {
- return new TensorFlowTransform(env, modelFile, source, new[] { name }).MakeDataTransform(input);
- }
-
- ///
- /// Convenience constructor for public facing API.
- ///
- /// Host Environment.
- /// Input . This is the output from previous transform or loader.
- /// This is the frozen tensorflow model file. https://www.tensorflow.org/mobile/prepare_models
+ /// Path to the TensorFlow model.
/// Name of the output column(s). Keep it same as in the Tensorflow model.
/// Name of the input column(s). Keep it same as in the Tensorflow model.
- public static IDataTransform Create(IHostEnvironment env, IDataView input, string modelFile, string[] names, string[] source)
+ public static IDataTransform Create(IHostEnvironment env, IDataView input, string model, string[] names, string[] source)
{
- return new TensorFlowTransform(env, modelFile, source, names).MakeDataTransform(input);
+ return new TensorFlowTransform(env, TensorFlowUtils.GetSession(env, model), source, names, TensorFlowUtils.IsSavedModel(env, model) ? model : null, false).MakeDataTransform(input);
}
// Factory method for SignatureLoadModel.
@@ -111,7 +100,9 @@ private static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
+
// *** Binary format ***
+ // byte: indicator for frozen models
// stream: tensorFlow model.
// int: number of input columns
// for each input column
@@ -119,27 +110,48 @@ private static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext
// int: number of output columns
// for each output column
// int: id of output column name
- byte[] modelBytes = null;
- if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray()))
- throw env.ExceptDecode();
- var session = TensorFlowUtils.LoadTFSession(env, modelBytes);
- var numInputs = ctx.Reader.ReadInt32();
- env.CheckDecode(numInputs > 0);
- string[] inputs = new string[numInputs];
- for (int j = 0; j < inputs.Length; j++)
- inputs[j] = ctx.LoadNonEmptyString();
-
- bool isMultiOutput = ctx.Header.ModelVerReadable >= 0x00010002;
- var numOutputs = 1;
- if (isMultiOutput)
- numOutputs = ctx.Reader.ReadInt32();
+ GetModelInfo(env, ctx, out string[] inputs, out string[] outputs, out bool isFrozen);
+ if (isFrozen)
+ {
+ byte[] modelBytes = null;
+ if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray()))
+ throw env.ExceptDecode();
+ return new TensorFlowTransform(env, TensorFlowUtils.LoadTFSession(env, modelBytes), inputs, outputs, null, false);
+ }
- env.CheckDecode(numOutputs > 0);
- var outputs = new string[numOutputs];
- for (int j = 0; j < outputs.Length; j++)
- outputs[j] = ctx.LoadNonEmptyString();
+ var tempDirPath = Path.GetFullPath(Path.Combine(Path.GetTempPath(), RegistrationName + "_" + Guid.NewGuid()));
+ TensorFlowUtils.CreateFolderWithAclIfNotExists(env, tempDirPath);
+ try
+ {
+ var load = ctx.TryLoadBinaryStream("TFSavedModel", br =>
+ {
+ int count = br.ReadInt32();
+ for (int n = 0; n < count; n++)
+ {
+ string relativeFile = br.ReadString();
+ long fileLength = br.ReadInt64();
+
+ string fullFilePath = Path.Combine(tempDirPath, relativeFile);
+ string fullFileDir = Path.GetDirectoryName(fullFilePath);
+ if (fullFileDir != tempDirPath)
+ {
+ TensorFlowUtils.CreateFolderWithAclIfNotExists(env, fullFileDir);
+ }
+ using (var fs = new FileStream(fullFilePath, FileMode.Create, FileAccess.Write))
+ {
+ long actualRead = br.BaseStream.CopyRange(fs, fileLength);
+ env.Assert(actualRead == fileLength);
+ }
+ }
+ });
- return new TensorFlowTransform(env, session, inputs, outputs);
+ return new TensorFlowTransform(env, TensorFlowUtils.GetSession(env, tempDirPath), inputs, outputs, tempDirPath, true);
+ }
+ catch (Exception)
+ {
+ TensorFlowUtils.DeleteFolderWithRetries(env, tempDirPath);
+ throw;
+ }
}
// Factory method for SignatureDataTransform.
@@ -150,7 +162,8 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, IData
env.CheckValue(input, nameof(input));
env.CheckValue(args.InputColumns, nameof(args.InputColumns));
env.CheckValue(args.OutputColumns, nameof(args.OutputColumns));
- return new TensorFlowTransform(env, args.ModelFile, args.InputColumns, args.OutputColumns).MakeDataTransform(input);
+
+ return new TensorFlowTransform(env, TensorFlowUtils.GetSession(env, args.Model), args.InputColumns, args.OutputColumns, TensorFlowUtils.IsSavedModel(env, args.Model) ? args.Model : null, false).MakeDataTransform(input);
}
// Factory method for SignatureLoadDataTransform.
@@ -161,27 +174,42 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx,
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
=> Create(env, ctx).MakeRowMapper(inputSchema);
- private static TFSession CheckFileAndRead(IHostEnvironment env, string modelFile)
+ private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out string[] inputs, out string[] outputs, out bool isFrozen)
{
- env.CheckNonWhiteSpace(modelFile, nameof(modelFile));
- env.CheckUserArg(File.Exists(modelFile), nameof(modelFile));
- var bytes = File.ReadAllBytes(modelFile);
- return TensorFlowUtils.LoadTFSession(env, bytes, modelFile);
- }
+ isFrozen = true;
+ bool isNonFrozenModelSupported = ctx.Header.ModelVerReadable >= 0x00010002;
+ if (isNonFrozenModelSupported)
+ isFrozen = ctx.Reader.ReadBoolByte();
- public TensorFlowTransform(IHostEnvironment env, string modelFile, string[] inputs, string[] outputs) :
- this(env, CheckFileAndRead(env, modelFile), inputs, outputs)
- {
+ var numInputs = ctx.Reader.ReadInt32();
+ env.CheckDecode(numInputs > 0);
+ inputs = new string[numInputs];
+ for (int j = 0; j < inputs.Length; j++)
+ inputs[j] = ctx.LoadNonEmptyString();
+
+ bool isMultiOutput = ctx.Header.ModelVerReadable >= 0x00010002;
+ var numOutputs = 1;
+ if (isMultiOutput)
+ numOutputs = ctx.Reader.ReadInt32();
+
+ env.CheckDecode(numOutputs > 0);
+ outputs = new string[numOutputs];
+ for (int j = 0; j < outputs.Length; j++)
+ outputs[j] = ctx.LoadNonEmptyString();
}
- private TensorFlowTransform(IHostEnvironment env, TFSession session, string[] inputs, string[] outputs)
+ internal TensorFlowTransform(IHostEnvironment env, TFSession session, string[] inputs, string[] outputs, string savedModelPath, bool isTemporarySavedModel)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(RegistrationName));
_host.CheckValue(session, nameof(session));
_host.CheckNonEmpty(inputs, nameof(inputs));
_host.CheckNonEmpty(outputs, nameof(outputs));
+
Session = session;
+ _savedModelPath = savedModelPath;
+ _isTemporarySavedModel = isTemporarySavedModel;
+
foreach (var input in inputs)
{
_host.CheckNonWhiteSpace(input, nameof(inputs));
@@ -261,7 +289,9 @@ public void Save(ModelSaveContext ctx)
_host.AssertValue(ctx);
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
+
// *** Binary format ***
+ // byte: indicator for frozen models
// stream: tensorFlow model.
// int: number of input columns
// for each input column
@@ -269,14 +299,39 @@ public void Save(ModelSaveContext ctx)
// int: number of output columns
// for each output column
// int: id of output column name
-
- var buffer = new TFBuffer();
- Session.Graph.ToGraphDef(buffer);
-
- ctx.SaveBinaryStream("TFModel", w =>
+ var isFrozen = string.IsNullOrEmpty(_savedModelPath);
+ ctx.Writer.WriteBoolByte(isFrozen);
+ if (isFrozen)
{
- w.WriteByteArray(buffer.ToArray());
- });
+ var buffer = new TFBuffer();
+ Session.Graph.ToGraphDef(buffer);
+ ctx.SaveBinaryStream("TFModel", w =>
+ {
+ w.WriteByteArray(buffer.ToArray());
+ });
+ }
+ else
+ {
+ ctx.SaveBinaryStream("TFSavedModel", w =>
+ {
+ string[] modelFilePaths = Directory.GetFiles(_savedModelPath, "*", SearchOption.AllDirectories);
+ w.Write(modelFilePaths.Length);
+
+ foreach (var fullPath in modelFilePaths)
+ {
+ var relativePath = fullPath.Substring(_savedModelPath.Length + 1);
+ w.Write(relativePath);
+
+ using (var fs = new FileStream(fullPath, FileMode.Open))
+ {
+ long fileLength = fs.Length;
+ w.Write(fileLength);
+ long actualWritten = fs.CopyRange(w.BaseStream, fileLength);
+ _host.Assert(actualWritten == fileLength);
+ }
+ }
+ });
+ }
_host.AssertNonEmpty(Inputs);
ctx.Writer.Write(Inputs.Length);
foreach (var colName in Inputs)
@@ -288,6 +343,33 @@ public void Save(ModelSaveContext ctx)
ctx.SaveNonEmptyString(colName);
}
+ ~TensorFlowTransform()
+ {
+ Dispose(false);
+ }
+
+ private void Dispose(bool disposing)
+ {
+ // Ensure that the Session is not null and its handle is not Zero, as it may have already been disposed/finalized.
+ // Technically we shouldn't be calling this if disposing == false, since we're running in finalizer
+ // and the GC doesn't guarantee ordering of finalization of managed objects, but we have to make sure
+ // that the Session is closed before deleting our temporary directory.
+ try
+ {
+ if (Session?.Handle != IntPtr.Zero)
+ {
+ Session.CloseSession();
+ Session.Dispose();
+ }
+ }
+ finally
+ {
+ if (!string.IsNullOrEmpty(_savedModelPath) && _isTemporarySavedModel)
+ {
+ TensorFlowUtils.DeleteFolderWithRetries(_host, _savedModelPath);
+ }
+ }
+ }
public bool IsRowToRowMapper => true;
public IRowToRowMapper GetRowToRowMapper(ISchema inputSchema)
@@ -322,6 +404,9 @@ public Mapper(IHostEnvironment env, TensorFlowTransform parent, ISchema inputSch
throw _host.Except($"Column {_parent.Inputs[i]} doesn't exist");
var type = inputSchema.GetColumnType(_inputColIndices[i]);
+ if (type.IsVector && type.VectorSize == 0)
+ throw _host.Except($"Variable length input columns not supported");
+
_isInputVector[i] = type.IsVector;
var expectedType = TensorFlowUtils.Tf2MlNetType(_parent.TFInputTypes[i]);
if (type.ItemType != expectedType)
@@ -564,8 +649,8 @@ public static CommonOutputs.TransformOutput TensorFlowScorer(IHostEnvironment en
public sealed class TensorFlowEstimator : TrivialEstimator
{
- public TensorFlowEstimator(IHostEnvironment env, string modelFile, string[] inputs, string[] outputs)
- : this(env, new TensorFlowTransform(env, modelFile, inputs, outputs))
+ public TensorFlowEstimator(IHostEnvironment env, string model, string[] inputs, string[] outputs)
+ : this(env, new TensorFlowTransform(env, TensorFlowUtils.GetSession(env, model), inputs, outputs, TensorFlowUtils.IsSavedModel(env, model) ? model : null, false))
{
}
@@ -584,7 +669,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
var input = Transformer.Inputs[i];
if (!inputSchema.TryFindColumn(input, out var col))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input);
- if (!(col.Kind == SchemaShape.Column.VectorKind.VariableVector || col.Kind == SchemaShape.Column.VectorKind.Vector))
+ if (!(col.Kind == SchemaShape.Column.VectorKind.Vector))
throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, nameof(VectorType), col.GetTypeString());
var expectedType = TensorFlowUtils.Tf2MlNetType(Transformer.TFInputTypes[i]);
if (col.ItemType != expectedType)
@@ -602,7 +687,6 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema)
public static class TensorFlowStaticExtensions
{
-
private sealed class OutColumn : Vector
{
public PipelineColumn Input { get; }
diff --git a/test/BaselineOutput/Common/EntryPoints/core_manifest.json b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
index 1ff41a090b..03f8cb170f 100644
--- a/test/BaselineOutput/Common/EntryPoints/core_manifest.json
+++ b/test/BaselineOutput/Common/EntryPoints/core_manifest.json
@@ -21729,12 +21729,9 @@
"ShortName": "TFTransform",
"Inputs": [
{
- "Name": "ModelFile",
+ "Name": "Model",
"Type": "String",
- "Desc": "This is the frozen protobuf model file. Please see https://www.tensorflow.org/mobile/prepare_models for more details.",
- "Aliases": [
- "model"
- ],
+ "Desc": "TensorFlow model used by the transform. Please see https://www.tensorflow.org/mobile/prepare_models for more details.",
"Required": true,
"SortOrder": 0.0,
"IsNullable": false
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs
index b3d7068041..2e5292d13f 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs
@@ -977,9 +977,9 @@ public void TestTensorFlowEntryPoint()
var tfTransformInput = new Legacy.Transforms.TensorFlowScorer
{
Data = importOutput.Data,
+ Model = "mnist_model/frozen_saved_model.pb",
InputColumns = new[] { "Placeholder" },
OutputColumns = new[] { "Softmax" },
- ModelFile = "mnist_model/frozen_saved_model.pb"
};
var tfTransformOutput = experiment.Add(tfTransformInput);
diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
index d674d99c61..b1e86c70a2 100644
--- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
+++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs
@@ -3733,7 +3733,7 @@ public void EntryPointTensorFlowTransform()
new[]
{
@"'InputColumns': [ 'Placeholder' ],
- 'ModelFile': 'mnist_model/frozen_saved_model.pb',
+ 'Model': 'mnist_model/frozen_saved_model.pb',
'OutputColumns': [ 'Softmax' ]"
});
}
diff --git a/test/Microsoft.ML.Tests/Scenarios/PipelineApi/TrainSaveModelAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/PipelineApi/TrainSaveModelAndPredict.cs
index 2c25745100..cfeed15bd5 100644
--- a/test/Microsoft.ML.Tests/Scenarios/PipelineApi/TrainSaveModelAndPredict.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/PipelineApi/TrainSaveModelAndPredict.cs
@@ -36,7 +36,6 @@ public async void TrainSaveModelAndPredict()
var loadedModel = await Legacy.PredictionModel.ReadAsync(modelName);
var singlePrediction = loadedModel.Predict(new SentimentData() { SentimentText = "Not big fan of this." });
Assert.True(singlePrediction.Sentiment);
-
}
}
}
diff --git a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs
index 6a03f114b8..4685999813 100644
--- a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs
@@ -44,7 +44,7 @@ public void TensorFlowTransformCifarLearningPipelineTest()
pipeline.Add(new TensorFlowScorer()
{
- ModelFile = model_location,
+ Model = model_location,
InputColumns = new[] { "Input" },
OutputColumns = new[] { "Output" }
});
@@ -121,7 +121,7 @@ public void TensorFlowTransformInceptionPipelineTest()
pipeline.Add(new TensorFlowScorer()
{
- ModelFile = model_location,
+ Model = model_location,
InputColumns = new[] { inputTensorName },
OutputColumns = new[] { outputTensorName }
});
diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
index 6ee86513bc..dfca4ebf6a 100644
--- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
+++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
@@ -44,7 +44,7 @@ public void TensorFlowTransformMatrixMultiplicationTest()
b = new[] { 3.0f, 3.0f,
3.0f, 3.0f } } }));
- var trans = TensorFlowTransform.Create(env, loader, model_location, "c", "a", "b");
+ var trans = TensorFlowTransform.Create(env, loader, model_location, new[] { "c" }, new[] { "a", "b" });
using (var cursor = trans.GetRowCursor(a => true))
{
@@ -163,7 +163,7 @@ public void TensorFlowTransformInceptionTest()
}
}, cropped);
- var tf = TensorFlowTransform.Create(env, pixels, model_location, "softmax2_pre_activation", "input");
+ var tf = TensorFlowTransform.Create(env, pixels, model_location, new[] { "softmax2_pre_activation" }, new[] { "input" });
tf.Schema.TryGetColumnIndex("input", out int input);
tf.Schema.TryGetColumnIndex("softmax2_pre_activation", out int b);
@@ -353,6 +353,88 @@ public void TensorFlowTransformMNISTConvTest()
}
}
+ [Fact]
+ public void TensorFlowTransformMNISTConvSavedModelTest()
+ {
+ var model_location = "mnist_model";
+ using (var env = new ConsoleEnvironment(seed: 1, conc: 1))
+ {
+ var dataPath = GetDataPath("Train-Tiny-28x28.txt");
+ var testDataPath = GetDataPath("MNIST.Test.tiny.txt");
+
+ // Pipeline
+ var loader = TextLoader.ReadFile(env,
+ new TextLoader.Arguments()
+ {
+ Separator = "tab",
+ HasHeader = true,
+ Column = new[]
+ {
+ new TextLoader.Column("Label", DataKind.Num,0),
+ new TextLoader.Column("Placeholder", DataKind.Num,new []{new TextLoader.Range(1, 784) })
+
+ }
+ }, new MultiFileSource(dataPath));
+
+ IDataView trans = CopyColumnsTransform.Create(env, new CopyColumnsTransform.Arguments()
+ {
+ Column = new[] { new CopyColumnsTransform.Column()
+ { Name = "reshape_input", Source = "Placeholder" }
+ }
+ }, loader);
+ trans = TensorFlowTransform.Create(env, trans, model_location, new[] { "Softmax", "dense/Relu" }, new[] { "Placeholder", "reshape_input" });
+ trans = new ConcatTransform(env, "Features", "Softmax", "dense/Relu").Transform(trans);
+
+ var trainer = new LightGbmMulticlassTrainer(env, "Label", "Features");
+
+ var cached = new CacheDataView(env, trans, prefetch: null);
+ var trainRoles = new RoleMappedData(cached, label: "Label", feature: "Features");
+ var pred = trainer.Train(trainRoles);
+
+ // Get scorer and evaluate the predictions from test data
+ IDataScorerTransform testDataScorer = GetScorer(env, trans, pred, testDataPath);
+ var metrics = Evaluate(env, testDataScorer);
+
+ Assert.Equal(0.99, metrics.AccuracyMicro, 2);
+ Assert.Equal(1.0, metrics.AccuracyMacro, 2);
+
+ // Create prediction engine and test predictions
+ var model = env.CreatePredictionEngine(testDataScorer);
+
+ var sample1 = new MNISTData()
+ {
+ Placeholder = new float[] { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 18, 18, 18, 126, 136, 175, 26,
+ 166, 255, 247, 127, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 30, 36, 94, 154, 170, 253, 253, 253, 253, 253,
+ 225, 172, 253, 242, 195, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 49, 238, 253, 253, 253, 253, 253, 253, 253,
+ 253, 251, 93, 82, 82, 56, 39, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 219, 253, 253, 253, 253, 253, 198,
+ 182, 247, 241, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 80, 156, 107, 253, 253, 205, 11, 0,
+ 43, 154, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 14, 1, 154, 253, 90, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 139, 253, 190, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 11, 190, 253, 70, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 35, 241, 225, 160, 108, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 81, 240, 253, 253, 119, 25, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 45, 186, 253, 253, 150, 27, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 16, 93, 252, 253, 187, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 249, 253, 249, 64, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 46, 130, 183, 253, 253, 207, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 39, 148, 229, 253, 253, 253, 250, 182, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 24, 114, 221, 253, 253, 253, 253, 201, 78, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 23, 66, 213, 253, 253, 253, 253, 198, 81, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 18, 171, 219, 253, 253, 253, 253, 195, 80, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 55, 172, 226, 253, 253, 253, 253, 244, 133, 11, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 136, 253, 253, 253, 212, 135, 132, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }
+ };
+
+ var prediction = model.Predict(sample1);
+
+ float max = -1;
+ int maxIndex = -1;
+ for (int i = 0; i < prediction.PredictedLabels.Length; i++)
+ {
+ if (prediction.PredictedLabels[i] > max)
+ {
+ max = prediction.PredictedLabels[i];
+ maxIndex = i;
+ }
+ }
+
+ Assert.Equal(5, maxIndex);
+ }
+ }
+
[Fact]
public void TensorFlowTransformMNISTConvPipelineTest()
{
@@ -364,7 +446,7 @@ public void TensorFlowTransformMNISTConvPipelineTest()
pipeline.Add(new Legacy.Transforms.ColumnCopier() { Column = new[] { new CopyColumnsTransformColumn() { Name = "reshape_input", Source = "Placeholder" } } });
pipeline.Add(new TensorFlowScorer()
{
- ModelFile = model_location,
+ Model = model_location,
OutputColumns = new[] { "Softmax", "dense/Relu" },
InputColumns = new[] { "Placeholder", "reshape_input" }
});
@@ -444,7 +526,61 @@ public void TensorFlowTransformCifar()
}, cropped);
- IDataView trans = TensorFlowTransform.Create(env, pixels, model_location, "Output", "Input");
+ IDataView trans = TensorFlowTransform.Create(env, pixels, model_location, new[] { "Output" }, new[] { "Input" });
+
+ trans.Schema.TryGetColumnIndex("Output", out int output);
+ using (var cursor = trans.GetRowCursor(col => col == output))
+ {
+ var buffer = default(VBuffer);
+ var getter = cursor.GetGetter>(output);
+ var numRows = 0;
+ while (cursor.MoveNext())
+ {
+ getter(ref buffer);
+ Assert.Equal(10, buffer.Length);
+ numRows += 1;
+ }
+ Assert.Equal(3, numRows);
+ }
+ }
+ }
+
+ [Fact]
+ public void TensorFlowTransformCifarSavedModel()
+ {
+ var model_location = "cifar_saved_model";
+
+ using (var env = new ConsoleEnvironment())
+ {
+ var imageHeight = 32;
+ var imageWidth = 32;
+ var dataFile = GetDataPath("images/images.tsv");
+ var imageFolder = Path.GetDirectoryName(dataFile);
+ var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
+ var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
+ {
+ Column = new ImageLoaderTransform.Column[1]
+ {
+ new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" }
+ },
+ ImageFolder = imageFolder
+ }, data);
+ var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
+ {
+ Column = new ImageResizerTransform.Column[1]{
+ new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop}
+ }
+ }, images);
+
+ var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments()
+ {
+ Column = new ImagePixelExtractorTransform.Column[1]{
+ new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "Input", UseAlpha=false, InterleaveArgb=true}
+ }
+ }, cropped);
+
+
+ IDataView trans = TensorFlowTransform.Create(env, pixels, model_location, new[] { "Output" }, new[] { "Input" });
trans.Schema.TryGetColumnIndex("Output", out int output);
using (var cursor = trans.GetRowCursor(col => col == output))
@@ -501,7 +637,7 @@ public void TensorFlowTransformCifarInvalidShape()
var thrown = false;
try
{
- IDataView trans = TensorFlowTransform.Create(env, pixels, model_location, "Output", "Input");
+ IDataView trans = TensorFlowTransform.Create(env, pixels, model_location, new[] { "Output" }, new[] { "Input" });
}
catch
{