Enable TensorFlowTransform to work with pre-trained models that are not frozen #853
Changes from 17 commits
```diff
@@ -5,6 +5,7 @@
 using System;
 using System.Collections.Generic;
 using System.IO;
+using System.IO.Compression;
 using System.Linq;
 using Microsoft.ML.Core.Data;
 using Microsoft.ML.Runtime;
```
```diff
@@ -36,21 +37,25 @@ public sealed class TensorFlowTransform : ITransformer, ICanSaveModel
 {
     public sealed class Arguments : TransformInputBase
     {
-        [Argument(ArgumentType.Required, HelpText = "This is the frozen protobuf model file. Please see https://www.tensorflow.org/mobile/prepare_models for more details.", ShortName = "model", SortOrder = 0)]
-        public string ModelFile;
+        [Argument(ArgumentType.Required, HelpText = "TensorFlow model used by the transform. Please see https://www.tensorflow.org/mobile/prepare_models for more details.", SortOrder = 0)]
+        public string Model;
+
+        [Argument(ArgumentType.AtMostOnce, HelpText = "Indicator for frozen models", ShortName = "frozen", SortOrder = 1)]
+        public bool IsFrozen = TensorFlowEstimator.Defaults.IsFrozen;
```
**Reviewer:** Is there a way to figure this out from the `Model` argument? For example, check whether it is a directory or a file, or try loading it in a try/catch? #Resolved

**Author:** Yes. Added a utility method `TensorFlowUtils.IsFrozenTensorFlowModel()` where we do this check for frozen/un-frozen based on whether the model is provided as a single file (frozen model) or a directory (SavedModel).
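A minimal sketch of what that utility could look like; the method name comes from the comment above, but the body here is an assumption rather than the merged code:

```csharp
using System.IO;

public static class TensorFlowUtils
{
    // A frozen model ships as a single protobuf file, while the SavedModel
    // format is a directory (saved_model.pb plus a variables/ subfolder),
    // so a directory check is enough to tell the two apart.
    public static bool IsFrozenTensorFlowModel(string modelPath)
        => !Directory.Exists(modelPath);
}
```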
```diff
-        [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "The names of the model inputs", ShortName = "inputs", SortOrder = 1)]
+        [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "The names of the model inputs", ShortName = "inputs", SortOrder = 2)]
         public string[] InputColumns;

-        [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "The name of the outputs", ShortName = "outputs", SortOrder = 2)]
+        [Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "The name of the outputs", ShortName = "outputs", SortOrder = 3)]
         public string[] OutputColumns;
     }

     private readonly IHost _host;
     private const string RegistrationName = "TensorFlowTransform";

     internal readonly TFSession Session;
```
**Reviewer:** Do all these fields need to be internal? #Resolved

**Author:** I refactored this, so we have two private fields (`_savedModelPath` and `_isTemporarySavedModel`).

**Reviewer:** I think the other ones can be private too; it looks like they are only used by `TensorFlowTransform` and `TensorFlowTransform.Mapper`, which should be able to access `TensorFlowTransform`'s private fields.
```diff
+        internal readonly bool IsFrozen;
```
**Reviewer:** You don't need them. All you care about is the session, and you already have it; this is just a detail of how the session was loaded. #Resolved
```diff
+        internal readonly string ExportDir;
         internal readonly ColumnType[] OutputTypes;
         internal readonly TFDataType[] TFOutputTypes;
         internal readonly TFDataType[] TFInputTypes;
```
```diff
@@ -82,44 +87,17 @@ private static VersionInfo GetVersionInfo()
         /// </summary>
         /// <param name="env">Host Environment.</param>
         /// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
-        /// <param name="modelFile">This is the frozen TensorFlow model file. https://www.tensorflow.org/mobile/prepare_models </param>
-        /// <param name="name">Name of the output column. Keep it same as in the TensorFlow model.</param>
-        /// <param name="source">Name of the input column(s). Keep it same as in the TensorFlow model.</param>
-        public static IDataTransform Create(IHostEnvironment env, IDataView input, string modelFile, string name, params string[] source)
-        {
-            return new TensorFlowTransform(env, modelFile, source, new[] { name }).MakeDataTransform(input);
-        }
-
-        /// <summary>
-        /// Convenience constructor for public facing API.
-        /// </summary>
-        /// <param name="env">Host Environment.</param>
-        /// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
-        /// <param name="modelFile">This is the frozen tensorflow model file. https://www.tensorflow.org/mobile/prepare_models </param>
+        /// <param name="model">This is the frozen tensorflow model file. https://www.tensorflow.org/mobile/prepare_models </param>
         /// <param name="names">Name of the output column(s). Keep it same as in the Tensorflow model.</param>
         /// <param name="source">Name of the input column(s). Keep it same as in the Tensorflow model.</param>
-        public static IDataTransform Create(IHostEnvironment env, IDataView input, string modelFile, string[] names, string[] source)
+        /// <param name="isFrozen">Indicator for frozen models</param>
+        public static IDataTransform Create(IHostEnvironment env, IDataView input, string model, string[] names, string[] source, bool isFrozen = TensorFlowEstimator.Defaults.IsFrozen)
         {
-            return new TensorFlowTransform(env, modelFile, source, names).MakeDataTransform(input);
+            return new TensorFlowTransform(env, model, source, names, isFrozen).MakeDataTransform(input);
         }
```
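For illustration, the updated factory method could be called like this; the environment, data view, and tensor names below are invented, but the signature is the one from the diff above:

```csharp
// Frozen model: a single protobuf file; isFrozen defaults to true.
var frozen = TensorFlowTransform.Create(env, input, "model.pb",
    names: new[] { "Softmax" }, source: new[] { "Placeholder" });

// Un-frozen SavedModel: a directory produced by TensorFlow's SavedModel
// builder; callers opt in with isFrozen: false.
var saved = TensorFlowTransform.Create(env, input, "saved_model_dir",
    names: new[] { "Softmax" }, source: new[] { "Placeholder" }, isFrozen: false);
```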
```diff
-        // Factory method for SignatureLoadModel.
-        public static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext ctx)
+        private static Tuple<string[], string[]> ModelInputsOutputs(IHostEnvironment env, ModelLoadContext ctx)
         {
-            Contracts.CheckValue(env, nameof(env));
-            env.CheckValue(ctx, nameof(ctx));
-            ctx.CheckAtModel(GetVersionInfo());
-            // *** Binary format ***
-            // stream: tensorFlow model.
-            // int: number of input columns
-            // for each input column
-            //   int: id of int column name
-            // int: number of output columns
-            // for each output column
-            //   int: id of output column name
-            byte[] modelBytes = null;
-            if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray()))
-                throw env.ExceptDecode();
             var numInputs = ctx.Reader.ReadInt32();
             env.CheckDecode(numInputs > 0);
             string[] inputs = new string[numInputs];
@@ -136,7 +114,49 @@ public static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext
             for (int j = 0; j < outputs.Length; j++)
                 outputs[j] = ctx.LoadNonEmptyString();

-            return new TensorFlowTransform(env, modelBytes, inputs, outputs);
+            return new Tuple<string[], string[]>(inputs, outputs);
         }
```
```diff
+        // Factory method for SignatureLoadModel.
+        public static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext ctx)
+        {
+            Contracts.CheckValue(env, nameof(env));
+            env.CheckValue(ctx, nameof(ctx));
+            ctx.CheckAtModel(GetVersionInfo());
+
+            // *** Binary format ***
+            // int: indicator for frozen models
```
**Reviewer:** Change to byte. #Closed
```diff
+            // stream: tensorFlow model.
+            // int: number of input columns
+            // for each input column
+            //   int: id of int column name
+            // int: number of output columns
+            // for each output column
+            //   int: id of output column name
+            var isFrozen = ctx.Reader.ReadInt32();
```
**Reviewer:** Need to increment the version number, since the serialized format has changed. We can still read the older version by setting `isFrozen = 1` without reading anything. #Resolved
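A sketch of the backward-compatible read that comment describes; `VerAddedSavedModel` is a hypothetical version constant, and `ctx.Header.ModelVerWritten` is assumed to expose the version the model was written with:

```csharp
// Models serialized before this change carry no indicator and are
// always frozen, so only read the flag for new-enough versions.
bool isFrozen = true;
if (ctx.Header.ModelVerWritten >= VerAddedSavedModel) // hypothetical constant
    isFrozen = ctx.Reader.ReadInt32() == 1;
```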
```diff
+            if (isFrozen == 1)
+            {
+                byte[] modelBytes = null;
+                if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray()))
+                    throw env.ExceptDecode();
+
+                var io = ModelInputsOutputs(env, ctx);
```
**Reviewer:** This line can be moved before the if. #Resolved

**Author:** If we move it before the if (line 127), won't it break the order in which it was saved? (The model inputs/outputs come only after we are done reading the stream contents.)

**Reviewer:** Since the model is saved in a separate stream, I think it doesn't matter. The order is important only for what ends up in the Model.key file.
```diff
+                return new TensorFlowTransform(env, modelBytes, (isFrozen == 1), io.Item1, io.Item2);
+            }
+            else
```
**Reviewer:** Since you are returning inside the "if", you don't need the "else" here. #Resolved

**Reviewer:** Actually, I think it would be better to have just one constructor for `TensorFlowTransform` that takes the session, inputs and outputs, and the saved-model dictionary as parameters. I think the `isFrozen` parameter is also not needed, since it can be inferred from the dictionary being null.

**Author:** I have removed the "else" condition. Also, we have moved from the dictionary approach to a stream-based approach; we can discuss further whether the code can be simplified to just one constructor.
```diff
+            {
```
**Author:** [Observation] The Train() method of the LearningPipeline API ends up calling Save() and Create() multiple times. For un-frozen models we need a particular directory structure so we can create the TFSession correctly using the TFSession.FromSavedModel API (line 215). In this call to Create(), we do not delete the temporary folder, because a subsequent call to Save() writes out the contents of the temporary folder as a byte stream. So for this scenario, what we have currently is that we define a unique location for the temporary folder, and it gets cleaned up only when required, inside the Create() call (line 139). Is there some way to ensure the temporary folder is always cleaned up? Or perhaps some way to not have a temporary folder at all? @ericstj @zeahmed @yaeldekel #Resolved

**Reviewer:** You should avoid the temporary zip file. If you can get a seekable stream from ctx, just pass that into the ZipArchive constructor (https://msdn.microsoft.com/es-es/library/system.io.compression.ziparchive.ziparchive(v=vs.110).aspx). It looks like you can do this by accessing BaseStream on the binary reader passed into the TryLoadBinaryStream action. You should also avoid allocating an extra byte array for the zip, and instead do all the reading within the callback from TryLoadBinaryStream, since it looks like TryLoadBinaryStream will close the stream.

I think you still need a temporary directory for the model files themselves. This looks like a requirement of TensorFlow (https://github.com/tensorflow/tensorflow/blob/3be04971716fcaf0c11ad9262e60efa428553e14/tensorflow/c/c_api.h#L1367): it requires the model to be in a directory and doesn't provide a mechanism for loading from some sort of stream/memory representation. So you'll still need a temp path that you'll have to clean up. Ideally you should provide a way for this path to be specified, rather than always calling GetTempPath behind the scenes. I think ML.NET has a pattern for this; if not, perhaps consider a constructor overload that allows a temp path root to be specified. You may also consider some way to say "don't clean up the temp path". Typically you should treat the temporary folder as state managed by the object and clean it up in the dispose/finalizer.

**Author:** I tried both of the suggestions above, but they seem to have some limitations with respect to this transform, so I have since modified the load-and-restore logic to do things in-memory. Documenting the issue hit with the suggestion above: the ZipArchive approach looked promising for avoiding the temporary zip file; however, it works on the entire stream rather than on a subset of the stream. I have modified the logic so we do the ML.NET serialization/de-serialization using in-memory data structures.

**Reviewer:** You should file an issue on #2. I think this approach looks much better than the zip-in-a-zip approach before. One change I'd suggest: instead of reading into a byte array and storing it in memory, create the file within the callback (var fs = new FileStream(...)) and use br.BaseStream.CopyTo(fs, fileLength); to do this you'll need to replace ReadByteArray/WriteByteArray with your own calls that store the file length, followed by the stream copy. One thing I notice that could be problematic for the frozen case above (not related to this change) is that it seems to use Graph.Import to import an entire model as a managed byte array; I can imagine this will potentially hit the upper limit of managed array sizes for large models. I see that TF's graph import functionality expects a single buffer (https://github.com/tensorflow/tensorflow/blob/3be04971716fcaf0c11ad9262e60efa428553e14/tensorflow/c/c_api.h#L1018-L1020).

**Author:** The main reason for reading into a byte array and storing it in memory is to handle the scenario where we saw the sequence Save1() -> Create1() -> Save2() -> Create2() being called. When Create1() is called, we load the contents into a dictionary, so when Save2() is called we use the dictionary to write out the contents. Because of issue #906 (not being able to clean up unmanaged resources), I am not sure we can do away with the in-memory approach. Thoughts?

**Author:** Adopted the stream-copy-based approach instead of reading into a byte array, and added a finalizer that closes the session (if it isn't closed) and deletes the temporary directory.

**Reviewer:** @eerhardt, @danmosemsft, would this be a security risk? To load the TensorFlow model here, we create a temporary directory, copy some files to it, and then load the model from that directory. #Resolved

**Author:** Followed the recommendation of the .NET security team to address this risk.
```diff
+                // Load model binary
+                byte[] tfFilesBin = null;
+                var load = ctx.TryLoadBinaryStream("TFSavedModel", br => tfFilesBin = br.ReadByteArray());
+                var tempDirName = Path.GetFullPath(Path.Combine(Path.GetTempPath(), "_MLNET_TFTransform_" + Guid.NewGuid()));
+                var tempDir = Directory.CreateDirectory(tempDirName);
+                var tfZipFilePath = Path.Combine(tempDir.FullName, "tf_savedmodel.zip");
+                File.WriteAllBytes(tfZipFilePath, tfFilesBin);
+                ZipFile.ExtractToDirectory(tfZipFilePath, Path.Combine(tempDir.FullName, "tf_savedmodel"));
+
+                var io = ModelInputsOutputs(env, ctx);
+                return new TensorFlowTransform(env, Path.Combine(tempDir.FullName, "tf_savedmodel"), io.Item1, io.Item2, (isFrozen == 1));
+            }
+        }
```
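The stream-copy approach the reviewers converged on might look roughly like this. The length-prefixed layout, helper name, and parameters are assumptions for illustration, not the code that was actually merged:

```csharp
using System;
using System.IO;

static class SavedModelStreams
{
    // Copies a length-prefixed payload from a BinaryReader to a file
    // without materializing it as one large managed byte[].
    public static void CopyPrefixedPayloadToFile(BinaryReader br, string destinationPath)
    {
        long remaining = br.ReadInt64(); // hypothetical length prefix written by Save
        var buffer = new byte[81920];
        using (var fs = new FileStream(destinationPath, FileMode.Create))
        {
            while (remaining > 0)
            {
                int chunk = (int)Math.Min(buffer.Length, remaining);
                int read = br.BaseStream.Read(buffer, 0, chunk);
                if (read <= 0)
                    throw new EndOfStreamException();
                fs.Write(buffer, 0, read);
                remaining -= read;
            }
        }
    }
}
```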
```diff
         // Factory method for SignatureDataTransform.
@@ -147,7 +167,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
             env.CheckValue(input, nameof(input));
             env.CheckValue(args.InputColumns, nameof(args.InputColumns));
             env.CheckValue(args.OutputColumns, nameof(args.OutputColumns));
-            return new TensorFlowTransform(env, args.ModelFile, args.InputColumns, args.OutputColumns).MakeDataTransform(input);
+            return new TensorFlowTransform(env, args.Model, args.InputColumns, args.OutputColumns, args.IsFrozen).MakeDataTransform(input);
         }
```
```diff
         // Factory method for SignatureLoadDataTransform.
@@ -158,7 +178,7 @@ public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx,
         public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
             => Create(env, ctx).MakeRowMapper(inputSchema);

-        private TFSession LoadTFSession(byte[] modelBytes)
+        private static TFSession LoadTFSession(IHostEnvironment env, byte[] modelBytes)
         {
             var graph = new TFGraph();
             try
@@ -168,38 +188,68 @@ private TFSession LoadTFSession(byte[] modelBytes)
             catch (Exception ex)
             {
 #pragma warning disable MSML_NoMessagesForLoadContext
-                throw _host.ExceptDecode(ex, "Tensorflow exception triggered while loading model.");
+                throw env.ExceptDecode(ex, "Tensorflow exception triggered while loading model.");
 #pragma warning restore MSML_NoMessagesForLoadContext
             }
             return new TFSession(graph);
         }
```
```diff
+        private static TFSession LoadTFSession(string exportDirSavedModel)
+        {
+            var sessionOptions = new TFSessionOptions();
+            var exportDir = exportDirSavedModel;
+            var tags = new string[] { "serve" };
+            var graph = new TFGraph();
+            var metaGraphDef = new TFBuffer();
+
+            var session = TFSession.FromSavedModel(sessionOptions, null, exportDir, tags, graph, metaGraphDef);
+            return session;
```
**Reviewer:** Nit: you can return directly from the line above. #Resolved
```diff
+        }
```
```diff
+        private static TFSession GetSession(IHostEnvironment env, string model, bool isFrozen)
+        {
+            if (isFrozen)
+            {
+                byte[] modelBytes = CheckFileAndRead(env, model);
+                return LoadTFSession(env, modelBytes);
+            }
+            else
```
**Reviewer:** The "else" is not needed here. #Resolved
```diff
+            {
+                return LoadTFSession(model);
+            }
+        }
+
         private static byte[] CheckFileAndRead(IHostEnvironment env, string modelFile)
         {
             env.CheckNonWhiteSpace(modelFile, nameof(modelFile));
             env.CheckUserArg(File.Exists(modelFile), nameof(modelFile));
             return File.ReadAllBytes(modelFile);
         }
```
|
||||
public TensorFlowTransform(IHostEnvironment env, string modelFile, string[] inputs, string[] outputs) : | ||||
this(env, CheckFileAndRead(env, modelFile), inputs, outputs) | ||||
public TensorFlowTransform(IHostEnvironment env, string model, string[] inputs, string[] outputs, bool isFrozen = TensorFlowEstimator.Defaults.IsFrozen) : | ||||
this(env, GetSession(env, model, isFrozen), isFrozen, inputs, outputs, model) | ||||
{ | ||||
} | ||||
|
||||
private TensorFlowTransform(IHostEnvironment env, byte[] modelBytes, string[] inputs, string[] outputs) | ||||
private TensorFlowTransform(IHostEnvironment env, byte[] modelBytes, bool isFrozen, string[] inputs, string[] outputs) : | ||||
this(env, LoadTFSession(env, modelBytes), isFrozen, inputs, outputs, null) | ||||
{ } | ||||
|
||||
private TensorFlowTransform(IHostEnvironment env, TFSession session, bool isFrozen, string[] inputs, string[] outputs, string exportDir) | ||||
{ | ||||
Contracts.CheckValue(env, nameof(env)); | ||||
_host = env.Register(nameof(RegistrationName)); | ||||
_host.CheckValue(modelBytes, nameof(modelBytes)); | ||||
Session = LoadTFSession(modelBytes); | ||||
Session = session; | ||||
IsFrozen = isFrozen; | ||||
ExportDir = exportDir; | ||||
foreach (var input in inputs) | ||||
{ | ||||
_host.CheckNonWhiteSpace(input, nameof(inputs)); | ||||
if (Session.Graph[input] == null) | ||||
throw _host.ExceptParam(nameof(inputs), $"Input column '{input}' does not exist in the model"); | ||||
var tfInput = new TFOutput(Session.Graph[input]); | ||||
if (!TensorFlowUtils.IsTypeSupported(tfInput.OutputType)) | ||||
throw _host.ExceptParam(nameof(modelBytes), $"Input type '{tfInput.OutputType}' of input column '{input}' is not supported in TensorFlow"); | ||||
throw _host.ExceptParam(nameof(session), $"Input type '{tfInput.OutputType}' of input column '{input}' is not supported in TensorFlow"); | ||||
} | ||||
|
||||
var newNames = new HashSet<string>(); | ||||
```diff
@@ -266,22 +316,40 @@ public void Save(ModelSaveContext ctx)
             _host.AssertValue(ctx);
             ctx.CheckAtModel();
             ctx.SetVersionInfo(GetVersionInfo());
+            ctx.Writer.Write(IsFrozen ? 1 : 0);

             // *** Binary format ***
+            // int: indicator for frozen models
             // stream: tensorFlow model.
             // int: number of input columns
             // for each input column
             //   int: id of int column name
             // int: number of output columns
             // for each output column
             //   int: id of output column name
-            var buffer = new TFBuffer();
-            Session.Graph.ToGraphDef(buffer);
-
-            ctx.SaveBinaryStream("TFModel", w =>
-            {
-                w.WriteByteArray(buffer.ToArray());
-            });
+            if (IsFrozen)
+            {
+                var buffer = new TFBuffer();
+                Session.Graph.ToGraphDef(buffer);
+
+                ctx.SaveBinaryStream("TFModel", w =>
+                {
+                    w.WriteByteArray(buffer.ToArray());
+                });
+            }
+            else
+            {
+                var tempDirName = Path.GetFullPath(Path.Combine(Path.GetTempPath(), "_MLNET_TFTransform_" + Guid.NewGuid()));
+                var tempDir = Directory.CreateDirectory(tempDirName);
+                var tfZipFilePath = Path.Combine(tempDir.FullName, "tf_savedmodel.zip");
+
+                ZipFile.CreateFromDirectory(ExportDir, tfZipFilePath, CompressionLevel.Fastest, false);
+                byte[] byteArray = File.ReadAllBytes(tfZipFilePath);
+                ctx.SaveBinaryStream("TFSavedModel", w =>
+                {
+                    w.WriteByteArray(byteArray);
+                });
+            }
             _host.AssertNonEmpty(Inputs);
             ctx.Writer.Write(Inputs.Length);
             foreach (var colName in Inputs)
```
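The finalizer mentioned in the review thread (close the session, delete the temporary directory) might be shaped like this. `_savedModelPath` and `_isTemporarySavedModel` are the private fields the author refers to above; the rest is a hedged sketch, not the merged code:

```csharp
~TensorFlowTransform()
{
    Cleanup();
}

private void Cleanup()
{
    try
    {
        Session?.CloseSession(); // close the TF session if it is still open
        if (_isTemporarySavedModel && Directory.Exists(_savedModelPath))
            Directory.Delete(_savedModelPath, recursive: true);
    }
    catch
    {
        // Best-effort cleanup: a finalizer must never throw.
    }
}
```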
```diff
@@ -527,8 +595,13 @@ public static CommonOutputs.TransformOutput TensorFlowScorer(IHostEnvironment en

     public sealed class TensorFlowEstimator : TrivialEstimator<TensorFlowTransform>
     {
-        public TensorFlowEstimator(IHostEnvironment env, string modelFile, string[] inputs, string[] outputs)
-            : this(env, new TensorFlowTransform(env, modelFile, inputs, outputs))
+        public static class Defaults
+        {
+            public const bool IsFrozen = true;
+        }
+
+        public TensorFlowEstimator(IHostEnvironment env, string modelFile, string[] inputs, string[] outputs, bool isFrozen = Defaults.IsFrozen)
+            : this(env, new TensorFlowTransform(env, modelFile, inputs, outputs, isFrozen))
         {
         }
```
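As a usage illustration, scoring against a SavedModel directory through the estimator might look like this (the directory, tensor names, and data view are invented):

```csharp
// Score with a SavedModel directory rather than a frozen .pb file.
var estimator = new TensorFlowEstimator(env, "mnist_saved_model",
    inputs: new[] { "Placeholder" }, outputs: new[] { "Softmax" }, isFrozen: false);
var transformer = estimator.Fit(data);
var scored = transformer.Transform(data);
```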
**Reviewer:** A) Do we really need them? The presence of Principal.Windows in multiplatform code is suspicious. B) If we really do need them, they should be part of the NuGet dependencies, and we should modify https://github.com/dotnet/machinelearning/blob/master/build/Dependencies.props instead of specifying versions here. #Resolved

**Author:** We use this as a security measure on the Windows platform. We use a temp directory for loading the TFSession for SavedModels. Models are considered executable code, so we need to ACL this temp directory in the high-rights process so that a low-rights process can't access it. I will look into making it part of the NuGet dependencies.

**Reviewer:** You can use this one as a reference implementation: https://github.com/dotnet/machinelearning/blob/master/pkg/Microsoft.ML.ImageAnalytics/Microsoft.ML.ImageAnalytics.nupkgproj
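A sketch of the kind of ACL restriction being discussed, using .NET Framework access-control APIs; the exact mitigation the team shipped is not shown in this excerpt, so treat this as an illustration only:

```csharp
using System.IO;
using System.Security.AccessControl;
using System.Security.Principal;

static class SecureTempDirectory
{
    // Creates a directory that only the current Windows user can access,
    // so a lower-rights process cannot tamper with the extracted model.
    public static DirectoryInfo Create(string path)
    {
        var security = new DirectorySecurity();
        // Cut off inherited rules so broad grants on %TEMP% do not apply.
        security.SetAccessRuleProtection(true, false);
        security.AddAccessRule(new FileSystemAccessRule(
            WindowsIdentity.GetCurrent().User,
            FileSystemRights.FullControl,
            InheritanceFlags.ContainerInherit | InheritanceFlags.ObjectInherit,
            PropagationFlags.None,
            AccessControlType.Allow));
        return Directory.CreateDirectory(path, security); // .NET Framework overload
    }
}
```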