
Enable TensorFlowTransform to work with pre-trained models that are not frozen #853

Merged
merged 54 commits into master from abgoswam/tf_savedmodel on Sep 25, 2018

Changes from 17 commits

Commits (54):
1007955
building transform from ground up
abgoswam Sep 6, 2018
40fbedc
dummy transform works after fixing the getters
abgoswam Sep 6, 2018
48d14c6
SavedModel format works for Train, but fails for Save&Predict
abgoswam Sep 6, 2018
35ff43a
remove dummy transform
abgoswam Sep 6, 2018
6291d0d
remove dummy unit test
abgoswam Sep 7, 2018
57508d3
Works with non-frozen models
abgoswam Sep 7, 2018
cfcd70f
building transform from ground up
abgoswam Sep 6, 2018
236de73
dummy transform works after fixing the getters
abgoswam Sep 6, 2018
781cff0
SavedModel format works for Train, but fails for Save&Predict
abgoswam Sep 6, 2018
07b15a0
remove dummy transform
abgoswam Sep 6, 2018
47f75b5
remove dummy unit test
abgoswam Sep 7, 2018
c304257
Merge branch 'abgoswam/tf_savedmodel' of https://github.com/abgoswam/…
abgoswam Sep 7, 2018
d0430b5
merge with master
abgoswam Sep 7, 2018
950a210
fix compilation issues; verify existing tests work fine
abgoswam Sep 7, 2018
97eb497
works locally; need to refactor code
abgoswam Sep 8, 2018
173729f
refactored code; keeping only 1 version of the convenience API
abgoswam Sep 10, 2018
292140b
Merge branch 'master' into abgoswam/tf_savedmodel
abgoswam Sep 10, 2018
655a8aa
added class for directory structure
abgoswam Sep 11, 2018
be5285a
using latest nuget package (0.0.3) for Microsoft.ML.TensorFlow.TestMo…
abgoswam Sep 11, 2018
46c04a3
delete temporary files used when loading/saving models
abgoswam Sep 11, 2018
e705f93
delete local models; the updated nuget version (0.0.3) for Microsoft.…
abgoswam Sep 11, 2018
84214e2
modified logic for load/restore of models
abgoswam Sep 12, 2018
04d02b8
modified logic for load&restore of unfrozen models
abgoswam Sep 13, 2018
89693bd
merge with latest dotnet/master
abgoswam Sep 13, 2018
8c8d92e
model version update to support non-frozen models
abgoswam Sep 13, 2018
d8edc64
based on the code review comments, we now infer if the provided model…
abgoswam Sep 13, 2018
eea524e
simplify the logic in Save() related to loading of SavedModel.
abgoswam Sep 13, 2018
b609ffd
trying Eric's suggestion
abgoswam Sep 13, 2018
74b8899
revert back to previous changes
abgoswam Sep 13, 2018
3382a83
attempt to use stream copy approach instead of in-memory
abgoswam Sep 14, 2018
aa8e844
taking care of some code review comments
abgoswam Sep 14, 2018
25b1e64
deleting some commented out code
abgoswam Sep 14, 2018
e32acca
Ensure we only copy the file segment & cleanup path logic
ericstj Sep 14, 2018
ac45539
added finalizer that closes the session (if it isn't closed) and dele…
abgoswam Sep 14, 2018
ce4efef
move away from using Dictionary<string, byte[]> and instead use strea…
abgoswam Sep 14, 2018
8b8764b
cleanup + misc review comments
abgoswam Sep 15, 2018
f955488
Merge branch 'master' into abgoswam/tf_savedmodel
abgoswam Sep 17, 2018
6e11f2c
trying to create temp dir with proper ACLs for high-privilege users
abgoswam Sep 19, 2018
ed71513
create temp dir with proper ACLs for high-privilege processes
abgoswam Sep 19, 2018
7df343d
Merge branch 'master' into abgoswam/tf_savedmodel
abgoswam Sep 19, 2018
f883d78
fix build after merge with latest master
abgoswam Sep 19, 2018
ae672d6
taking care of review comments related to model versioning of TFTrans…
abgoswam Sep 19, 2018
fac8dae
remove IDisposable from the TensorFlowTransform; renaming some methods
abgoswam Sep 20, 2018
2b1a576
refactor code so we have only 1 constructor for the TensorFlowTransfo…
abgoswam Sep 20, 2018
21879f6
merge with latest master
abgoswam Sep 20, 2018
a1d912d
fix issues with nuget packaging; refactored the code + added comments
abgoswam Sep 21, 2018
f6a1c84
add checks in code to make sure that the input is not a variable leng…
abgoswam Sep 21, 2018
5957c53
merge with latest master
abgoswam Sep 22, 2018
5120bb9
fix typo in name of package
abgoswam Sep 22, 2018
a624a3b
(1) added SavedModel test for MNIST model (2) added try/finally for d…
abgoswam Sep 24, 2018
8d9fdc5
remove and sort usings in file TrainSaveModelAndPredict.cs
abgoswam Sep 24, 2018
685ab99
using spaces in nupkgproj
abgoswam Sep 25, 2018
2d0ec1e
error checking for passed in IHostEnvironment
abgoswam Sep 25, 2018
8d8b986
fix TargetFramework version (netcore 2.0) of DnnAnalyzer to match tha…
abgoswam Sep 25, 2018
4 changes: 4 additions & 0 deletions src/Microsoft.ML.TensorFlow/Microsoft.ML.TensorFlow.csproj
@@ -7,6 +7,10 @@
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

<ItemGroup>
<Compile Remove="AgTransform.cs" />
</ItemGroup>
Contributor @Ivanidzo4ka, Sep 21, 2018:

A) Do we really need them? The presence of Principal.Windows in multiplatform code is suspicious.
B) If we really need them, you need to make them part of the nuget dependency, and we also need to modify https://github.com/dotnet/machinelearning/blob/master/build/Dependencies.props instead of specifying versions here. #Resolved

Member Author @abgoswam, Sep 21, 2018:

We use this as a security measure on the Windows platform. We use a temp directory for loading the TFSession for SavedModels. Models are considered executable code, so we need to ACL this temp directory in the high-rights process so that a low-rights process can’t access it.

I will look into making it part of the nuget dependency.


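A minimal sketch of that mitigation, for illustration only (this is not the PR's exact code; the DirectoryInfo.Create(DirectorySecurity) overload is assumed to be available, e.g. via the System.IO.FileSystem.AccessControl package on .NET Core):

using System;
using System.IO;
using System.Security.AccessControl;
using System.Security.Principal;

internal static class SecureTempDir
{
    // Create a temp directory that only the current (high-rights) user can
    // access, so extracted SavedModel files are unreadable by low-rights processes.
    public static string Create()
    {
        string path = Path.Combine(Path.GetTempPath(), "_MLNET_TFTransform_" + Guid.NewGuid());

        var security = new DirectorySecurity();
        // Drop inherited ACEs and grant full control to the current user only.
        security.SetAccessRuleProtection(isProtected: true, preserveInheritance: false);
        security.AddAccessRule(new FileSystemAccessRule(
            WindowsIdentity.GetCurrent().User,
            FileSystemRights.FullControl,
            InheritanceFlags.ContainerInherit | InheritanceFlags.ObjectInherit,
            PropagationFlags.None,
            AccessControlType.Allow));

        new DirectoryInfo(path).Create(security);
        return path;
    }
}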

Contributor:

You can use this one as a reference implementation: https://github.com/dotnet/machinelearning/blob/master/pkg/Microsoft.ML.ImageAnalytics/Microsoft.ML.ImageAnalytics.nupkgproj




<ItemGroup>
<ProjectReference Include="..\Microsoft.ML.Core\Microsoft.ML.Core.csproj" />
<ProjectReference Include="..\Microsoft.ML.Data\Microsoft.ML.Data.csproj" />
2 changes: 1 addition & 1 deletion src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs
@@ -1116,7 +1116,7 @@ public IEnumerable<DeviceAttributes> ListDevices(TFStatus status = null)
/// here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md
/// </para>
/// </remarks>
public TFSession FromSavedModel(TFSessionOptions sessionOptions, TFBuffer runOptions, string exportDir, string[] tags, TFGraph graph, TFBuffer metaGraphDef, TFStatus status = null)
public static TFSession FromSavedModel(TFSessionOptions sessionOptions, TFBuffer runOptions, string exportDir, string[] tags, TFGraph graph, TFBuffer metaGraphDef, TFStatus status = null)
{
if (graph == null)
throw new ArgumentNullException(nameof(graph));
181 changes: 127 additions & 54 deletions src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
@@ -5,6 +5,7 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Compression;
using System.Linq;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime;
@@ -36,21 +37,25 @@ public sealed class TensorFlowTransform : ITransformer, ICanSaveModel
{
public sealed class Arguments : TransformInputBase
{
[Argument(ArgumentType.Required, HelpText = "TensorFlow model used by the transform. Please see https://www.tensorflow.org/mobile/prepare_models for more details.", SortOrder = 0)]
public string Model;

[Argument(ArgumentType.Required, HelpText = "This is the frozen protobuf model file. Please see https://www.tensorflow.org/mobile/prepare_models for more details.", ShortName = "model", SortOrder = 0)]
public string ModelFile;
[Argument(ArgumentType.AtMostOnce, HelpText = "Indicator for frozen models", ShortName = "frozen", SortOrder = 1)]
public bool IsFrozen = TensorFlowEstimator.Defaults.IsFrozen;
@yaeldekel, Sep 13, 2018:

Re: IsFrozen

Is there a way to figure this out from the Model argument? For example, check if it is a directory or a file, or try loading it in a try/catch? #Resolved

Member Author @abgoswam, Sep 13, 2018:

Yes. Added a utility method TensorFlowUtils.IsFrozenTensorFlowModel(), which checks for frozen vs. un-frozen based on whether the model is provided as a single file (frozen model) or a directory (SavedModel).



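A minimal sketch of that inference, assuming only the file-vs-directory distinction described above (the actual TensorFlowUtils.IsFrozenTensorFlowModel may differ in detail):

using System.IO;

internal static class TensorFlowUtilsSketch
{
    // A frozen model ships as a single protobuf file; a SavedModel is a
    // directory (saved_model.pb plus a variables/ subfolder).
    public static bool IsFrozenTensorFlowModel(string modelPath)
        => !Directory.Exists(modelPath);
}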

[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "The names of the model inputs", ShortName = "inputs", SortOrder = 1)]
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "The names of the model inputs", ShortName = "inputs", SortOrder = 2)]
public string[] InputColumns;

[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "The name of the outputs", ShortName = "outputs", SortOrder = 2)]
[Argument(ArgumentType.Multiple | ArgumentType.Required, HelpText = "The name of the outputs", ShortName = "outputs", SortOrder = 3)]
public string[] OutputColumns;
}

private readonly IHost _host;
private const string RegistrationName = "TensorFlowTransform";

internal readonly TFSession Session;
@yaeldekel, Sep 20, 2018:

Re: internal

Do all these fields need to be internal? #Resolved

Member Author @abgoswam:

I refactored this, so we have 2 private fields (_savedModelPath and _isTemporarySavedModel)



Reply:
I think the other ones can be private too; it looks like they are only used by TensorFlowTransform and TensorFlowTransform.Mapper, which should be able to access TensorFlowTransform's private fields.



internal readonly bool IsFrozen;
Contributor @Ivanidzo4ka, Sep 10, 2018:

Re: IsFrozen

You don't need them. All you care about is the session, and you already have it. This is just a detail of how the session was loaded. #Resolved

internal readonly string ExportDir;
internal readonly ColumnType[] OutputTypes;
internal readonly TFDataType[] TFOutputTypes;
internal readonly TFDataType[] TFInputTypes;
@@ -82,44 +87,17 @@ private static VersionInfo GetVersionInfo()
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="modelFile">This is the frozen TensorFlow model file. https://www.tensorflow.org/mobile/prepare_models </param>
/// <param name="name">Name of the output column. Keep it same as in the TensorFlow model.</param>
/// <param name="source">Name of the input column(s). Keep it same as in the TensorFlow model.</param>
public static IDataTransform Create(IHostEnvironment env, IDataView input, string modelFile, string name, params string[] source)
{
return new TensorFlowTransform(env, modelFile, source, new[] { name }).MakeDataTransform(input);
}

/// <summary>
/// Convenience constructor for public facing API.
/// </summary>
/// <param name="env">Host Environment.</param>
/// <param name="input">Input <see cref="IDataView"/>. This is the output from previous transform or loader.</param>
/// <param name="modelFile">This is the frozen tensorflow model file. https://www.tensorflow.org/mobile/prepare_models </param>
/// <param name="model">This is the frozen tensorflow model file. https://www.tensorflow.org/mobile/prepare_models </param>
/// <param name="names">Name of the output column(s). Keep it same as in the Tensorflow model.</param>
/// <param name="source">Name of the input column(s). Keep it same as in the Tensorflow model.</param>
public static IDataTransform Create(IHostEnvironment env, IDataView input, string modelFile, string[] names, string[] source)
/// <param name="isFrozen">Indicator for frozen models</param>
public static IDataTransform Create(IHostEnvironment env, IDataView input, string model, string[] names, string[] source, bool isFrozen = TensorFlowEstimator.Defaults.IsFrozen)
{
return new TensorFlowTransform(env, modelFile, source, names).MakeDataTransform(input);
return new TensorFlowTransform(env, model, source, names, isFrozen).MakeDataTransform(input);
}

// Factory method for SignatureLoadModel.
public static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext ctx)
private static Tuple<string[], string[]> ModelInputsOutputs(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());
// *** Binary format ***
// stream: tensorFlow model.
// int: number of input columns
// for each input column
// int: id of int column name
// int: number of output columns
// for each output column
// int: id of output column name
byte[] modelBytes = null;
if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray()))
throw env.ExceptDecode();
var numInputs = ctx.Reader.ReadInt32();
env.CheckDecode(numInputs > 0);
string[] inputs = new string[numInputs];
@@ -136,7 +114,49 @@ public static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext
for (int j = 0; j < outputs.Length; j++)
outputs[j] = ctx.LoadNonEmptyString();

return new TensorFlowTransform(env, modelBytes, inputs, outputs);
return new Tuple<string[], string[]>(inputs, outputs);
}

// Factory method for SignatureLoadModel.
public static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext ctx)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ctx, nameof(ctx));
ctx.CheckAtModel(GetVersionInfo());

// *** Binary format ***
// int: indicator for frozen models
@yaeldekel, Sep 14, 2018:

Re: int

Change to byte. #Closed

// stream: tensorFlow model.
// int: number of input columns
// for each input column
// int: id of int column name
// int: number of output columns
// for each output column
// int: id of output column name
var isFrozen = ctx.Reader.ReadInt32();
@yaeldekel, Sep 13, 2018:

Re: var isFrozen = ctx.Reader.ReadInt32();

Need to increment the version number since the serialized format has changed. We can still read the older version by setting isFrozen=1 without reading anything. #Resolved

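A sketch of the backward-compatible read being requested (the version constant is hypothetical, and the shipped code may read an int or a byte rather than a bool):

// Models written before the format change carry no flag; treat them as frozen.
const uint VerAddedIsFrozenFlag = 0x00010002; // hypothetical version that introduced the flag
bool isFrozen = true;
if (ctx.Header.ModelVerWritten >= VerAddedIsFrozenFlag)
    isFrozen = ctx.Reader.ReadBoolean();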
if (isFrozen == 1)
{
byte[] modelBytes = null;
if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray()))
throw env.ExceptDecode();

var io = ModelInputsOutputs(env, ctx);
@yaeldekel, Sep 13, 2018:

Re: var io = ModelInputsOutputs(env, ctx);

This line can be moved before the if. #Resolved

Member Author @abgoswam:

If we move it before the if (line 127), won't it break the order in which it was saved?

(The model inputs/outputs come only after we are done reading the stream contents.)



Reply:
Since the model is saved in a separate stream, I think it doesn't matter. The order is important only for what ends up in the Model.key file.



return new TensorFlowTransform(env, modelBytes, (isFrozen == 1), io.Item1, io.Item2);
}
else
@yaeldekel, Sep 14, 2018:

Re: else

Since you are returning inside the "if", you don't need to have "else" here. #Resolved

Reply:
Actually, I think it would be better to have just one constructor for TensorFlowTransform, that takes the session, inputs and outputs, and the saved model dictionary as parameters. I think the isFrozen parameter is also not needed, since it can be inferred from the fact that the dictionary is null.



Member Author @abgoswam:

I have removed the "else" condition. Also, we have moved from the dictionary approach to a stream-based approach.

We can discuss further whether we can simplify the code to have just one constructor.



Member Author @abgoswam, Sep 19, 2018:

We now have 2 constructors.



{
Member Author @abgoswam, Sep 12, 2018:

Re: {

[observation] The Train() method of the LearningPipeline API ends up calling Save() and Create() multiple times.

For un-frozen models we need a particular directory structure so we can create the TFSession correctly using the TFSession.FromSavedModel API (line 215)

In this call to Create() we do the following :
(a) load the binary stream
(b) write the binary stream out to a zip file
(c) extract the zip file to a temporary folder TempSavedModelDirName (_savedModel)
(d) create the session
(e) delete the zip file

We do not delete the temporary folder because in a subsequent call to Save() we are writing out the contents of the temporary folder as a byte stream.

So for this scenario, what we have currently is that we define a unique location for the temporary folder -- and it gets cleaned up only when required, inside the Create() call (line 139).

Is there some way to ensure the temporary folder is always cleaned up? Or perhaps some way to not have a temporary folder at all?

@ericstj @zeahmed @yaeldekel #Resolved

Member @ericstj, Sep 12, 2018:

You should avoid the temporary zip file. If you can get a seekable stream from ctx, just pass that into the ZipArchive constructor: https://msdn.microsoft.com/es-es/library/system.io.compression.ziparchive.ziparchive(v=vs.110).aspx. It looks like you can do this by accessing the BaseStream on the binary reader passed into the TryLoadBinaryStream action.

You should also avoid allocating an extra byte array for the zip and instead do all the reading within the callback from TryLoadBinaryStream, since it looks like TryLoadBinaryStream will close the stream. So within this callback, new up the archive, then call the archive.ExtractToDirectory extension method: https://docs.microsoft.com/en-us/dotnet/api/system.io.compression.zipfileextensions.extracttodirectory?view=netframework-4.7.2.

I think you still need to have a temporary directory for the model files themselves. It looks like this is a requirement of TensorFlow https://github.com/tensorflow/tensorflow/blob/3be04971716fcaf0c11ad9262e60efa428553e14/tensorflow/c/c_api.h#L1367. It requires the model to be in a directory and doesn't provide a mechanism for loading from some sort of stream/memory representation. So you'll still need a temp path that you'll have to clean up.

Ideally you should provide a way for this path to be specified, rather than always calling GetTempPath behind the scenes. I think ML.NET has a pattern for this; if not, perhaps consider a constructor overload that allows for specification of a temp path root. You may also consider some way to say "don't clean up the temp path".

Typically with things like this you should treat the temporary folder as state managed by the object and then cleaned up in the dispose/finalizer.


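A sketch of that suggestion (tempDirName is assumed to be defined; see abgoswam's follow-up below for why ZipArchive, which operates on the entire stream rather than a sub-range, ended up being a poor fit here):

ctx.TryLoadBinaryStream("TFSavedModel", br =>
{
    // Read the zip directly from the deserialization stream; no temp .zip file.
    using (var archive = new ZipArchive(br.BaseStream, ZipArchiveMode.Read, leaveOpen: true))
    {
        // ZipFileExtensions.ExtractToDirectory, from System.IO.Compression.FileSystem.
        archive.ExtractToDirectory(tempDirName);
    }
});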

Member Author @abgoswam:

I tried both of the suggestions above, but they seem to have some limitations with respect to this transform. I have since modified the logic of load & restore to do things in-memory.

Documenting the issues hit with the above suggestions:

(1) The ZipArchive approach looked promising for avoiding the temporary zip file. However, it works on the entire stream, rather than on a subset of the stream.
(2) Currently the TFTransform does not dispose the TFSession object, so attempts to clean up the temp folder fail with 'file still in use' errors. The approach of using a unique location for the temporary folder is problematic in general (e.g. during CV, or if the pipeline contains multiple TFTransforms).

I have modified the logic so we do the ML.NET serialization/de-serialization using in-memory data structures.



Member @ericstj, Sep 13, 2018:

You should file an issue on (2).

I think this approach looks much better than the zip-in-a-zip approach from before. One change I'd suggest: instead of reading to a byte array and storing in memory, create the file within the callback (var fs = new FileStream(...)) and use br.BaseStream.CopyTo(fs, fileLength). To do this you'll need to replace readbytearray/writebytearray with your own calls that store the file length, followed by the stream copy.

One thing I notice that could be problematic for the frozen case above (not related to this change) is that it seems to use Graph.Import to import an entire model as a managed byte array. I can imagine that this will potentially hit the upper limit of managed array sizes for large models. I see that TF's graph import functionality expects a single buffer (https://github.com/tensorflow/tensorflow/blob/3be04971716fcaf0c11ad9262e60efa428553e14/tensorflow/c/c_api.h#L1018-L1020).


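A sketch of that bounded stream copy (destPath is assumed; note that Stream.CopyTo's second argument is a buffer size, not a byte count, so the length-limited copy has to be done manually):

ctx.TryLoadBinaryStream("TFSavedModel", br =>
{
    long remaining = br.ReadInt64(); // file length written by the corresponding save path
    using (var fs = new FileStream(destPath, FileMode.CreateNew))
    {
        var buffer = new byte[81920];
        while (remaining > 0)
        {
            int read = br.Read(buffer, 0, (int)Math.Min(buffer.Length, remaining));
            if (read <= 0)
                throw new EndOfStreamException();
            fs.Write(buffer, 0, read);
            remaining -= read;
        }
    }
});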

Member Author @abgoswam, Sep 13, 2018:

The main reason for reading to a byte array and storing in memory is to handle the scenario where we saw the sequence Save1() -> Create1() -> Save2() -> Create2() being called.

When Create1() is called, we load up the contents into a dictionary. So when Save2() is called, we use the dictionary to write out the contents.

Because of issue #906 (not being able to clean up unmanaged resources), I am not sure we can do away with the in-memory approach. Thoughts?



Member Author @abgoswam:

Adopted the stream-copy-based approach instead of reading to a byte array.

Also added a finalizer that closes the session (if it isn’t closed) and deletes the temporary directory.


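A minimal sketch of that finalizer (field names follow the _savedModelPath/_isTemporarySavedModel refactoring mentioned earlier in this review; the real implementation may differ):

~TensorFlowTransform()
{
    try
    {
        // Close the native session if it is still open.
        Session?.CloseSession();
        Session?.Dispose();
    }
    finally
    {
        // Remove the extracted SavedModel directory if it was a temporary one.
        if (_isTemporarySavedModel && _savedModelPath != null && Directory.Exists(_savedModelPath))
            Directory.Delete(_savedModelPath, recursive: true);
    }
}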

@yaeldekel, Sep 17, 2018:

@eerhardt, @danmosemsft, would this be a security risk? To load the TensorFlow model here, we create a temporary directory and copy some files to it, and then load the model from that directory. #Resolved

Member Author @abgoswam:

Followed the recommendation of the .NET security team to address this risk.



// Load model binary
byte[] tfFilesBin = null;
var load = ctx.TryLoadBinaryStream("TFSavedModel", br => tfFilesBin = br.ReadByteArray());
var tempDirName = Path.GetFullPath(Path.Combine(Path.GetTempPath(), "_MLNET_TFTransform_" + Guid.NewGuid()));
var tempDir = Directory.CreateDirectory(tempDirName);
var tfZipFilePath = Path.Combine(tempDir.FullName, "tf_savedmodel.zip");
File.WriteAllBytes(tfZipFilePath, tfFilesBin);
ZipFile.ExtractToDirectory(tfZipFilePath, Path.Combine(tempDir.FullName, "tf_savedmodel"));

var io = ModelInputsOutputs(env, ctx);
return new TensorFlowTransform(env, Path.Combine(tempDir.FullName, "tf_savedmodel"), io.Item1, io.Item2, (isFrozen == 1));
}
}

// Factory method for SignatureDataTransform.
@@ -147,7 +167,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV
env.CheckValue(input, nameof(input));
env.CheckValue(args.InputColumns, nameof(args.InputColumns));
env.CheckValue(args.OutputColumns, nameof(args.OutputColumns));
return new TensorFlowTransform(env, args.ModelFile, args.InputColumns, args.OutputColumns).MakeDataTransform(input);
return new TensorFlowTransform(env, args.Model, args.InputColumns,args.OutputColumns, args.IsFrozen).MakeDataTransform(input);
}

// Factory method for SignatureLoadDataTransform.
@@ -158,7 +178,7 @@ public static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx,
public static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
=> Create(env, ctx).MakeRowMapper(inputSchema);

private TFSession LoadTFSession(byte[] modelBytes)
private static TFSession LoadTFSession(IHostEnvironment env, byte[] modelBytes)
{
var graph = new TFGraph();
try
@@ -168,38 +188,68 @@ private TFSession LoadTFSession(byte[] modelBytes)
catch (Exception ex)
{
#pragma warning disable MSML_NoMessagesForLoadContext
throw _host.ExceptDecode(ex, "Tensorflow exception triggered while loading model.");
throw env.ExceptDecode(ex, "Tensorflow exception triggered while loading model.");
#pragma warning restore MSML_NoMessagesForLoadContext
}
return new TFSession(graph);
}

private static TFSession LoadTFSession(string exportDirSavedModel)
{
var sessionOptions = new TFSessionOptions();
var exportDir = exportDirSavedModel;
var tags = new string[] { "serve" };
var graph = new TFGraph();
var metaGraphDef = new TFBuffer();

var session = TFSession.FromSavedModel(sessionOptions, null, exportDir, tags, graph, metaGraphDef);
return session;
@yaeldekel, Sep 13, 2018:

Re: return session;

nit: you can directly return from the line above. #Resolved

}

private static TFSession GetSession(IHostEnvironment env, string model, bool isFrozen)
{
if (isFrozen)
{
byte[] modelBytes = CheckFileAndRead(env, model);
return LoadTFSession(env, modelBytes);
}
else
@yaeldekel, Sep 13, 2018:

Re: else

The "else" is not needed here. #Resolved

{
return LoadTFSession(model);
}
}

private static byte[] CheckFileAndRead(IHostEnvironment env, string modelFile)
{
env.CheckNonWhiteSpace(modelFile, nameof(modelFile));
env.CheckUserArg(File.Exists(modelFile), nameof(modelFile));
return File.ReadAllBytes(modelFile);
}

public TensorFlowTransform(IHostEnvironment env, string modelFile, string[] inputs, string[] outputs) :
this(env, CheckFileAndRead(env, modelFile), inputs, outputs)
public TensorFlowTransform(IHostEnvironment env, string model, string[] inputs, string[] outputs, bool isFrozen = TensorFlowEstimator.Defaults.IsFrozen) :
this(env, GetSession(env, model, isFrozen), isFrozen, inputs, outputs, model)
{
}

private TensorFlowTransform(IHostEnvironment env, byte[] modelBytes, string[] inputs, string[] outputs)
private TensorFlowTransform(IHostEnvironment env, byte[] modelBytes, bool isFrozen, string[] inputs, string[] outputs) :
this(env, LoadTFSession(env, modelBytes), isFrozen, inputs, outputs, null)
{ }

private TensorFlowTransform(IHostEnvironment env, TFSession session, bool isFrozen, string[] inputs, string[] outputs, string exportDir)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(RegistrationName));
_host.CheckValue(modelBytes, nameof(modelBytes));
Session = LoadTFSession(modelBytes);
Session = session;
IsFrozen = isFrozen;
ExportDir = exportDir;
foreach (var input in inputs)
{
_host.CheckNonWhiteSpace(input, nameof(inputs));
if (Session.Graph[input] == null)
throw _host.ExceptParam(nameof(inputs), $"Input column '{input}' does not exist in the model");
var tfInput = new TFOutput(Session.Graph[input]);
if (!TensorFlowUtils.IsTypeSupported(tfInput.OutputType))
throw _host.ExceptParam(nameof(modelBytes), $"Input type '{tfInput.OutputType}' of input column '{input}' is not supported in TensorFlow");
throw _host.ExceptParam(nameof(session), $"Input type '{tfInput.OutputType}' of input column '{input}' is not supported in TensorFlow");
}

var newNames = new HashSet<string>();
@@ -266,22 +316,40 @@ public void Save(ModelSaveContext ctx)
_host.AssertValue(ctx);
ctx.CheckAtModel();
ctx.SetVersionInfo(GetVersionInfo());
ctx.Writer.Write(IsFrozen ? 1 : 0);

// *** Binary format ***
// int: indicator for frozen models
// stream: tensorFlow model.
// int: number of input columns
// for each input column
// int: id of int column name
// int: number of output columns
// for each output column
// int: id of output column name
if (IsFrozen)
{
var buffer = new TFBuffer();
Session.Graph.ToGraphDef(buffer);

var buffer = new TFBuffer();
Session.Graph.ToGraphDef(buffer);

ctx.SaveBinaryStream("TFModel", w =>
ctx.SaveBinaryStream("TFModel", w =>
{
w.WriteByteArray(buffer.ToArray());
});
}
else
{
w.WriteByteArray(buffer.ToArray());
});
var tempDirName = Path.GetFullPath(Path.Combine(Path.GetTempPath(), "_MLNET_TFTransform_" + Guid.NewGuid()));
var tempDir = Directory.CreateDirectory(tempDirName);
var tfZipFilePath = Path.Combine(tempDir.FullName, "tf_savedmodel.zip");

ZipFile.CreateFromDirectory(ExportDir, tfZipFilePath, CompressionLevel.Fastest, false);
byte[] byteArray = File.ReadAllBytes(tfZipFilePath);
ctx.SaveBinaryStream("TFSavedModel", w =>
{
w.WriteByteArray(byteArray);
});
}
_host.AssertNonEmpty(Inputs);
ctx.Writer.Write(Inputs.Length);
foreach (var colName in Inputs)
@@ -527,8 +595,13 @@ public static CommonOutputs.TransformOutput TensorFlowScorer(IHostEnvironment en

public sealed class TensorFlowEstimator : TrivialEstimator<TensorFlowTransform>
{
public TensorFlowEstimator(IHostEnvironment env, string modelFile, string[] inputs, string[] outputs)
: this(env, new TensorFlowTransform(env, modelFile, inputs, outputs))

public static class Defaults
{
public const bool IsFrozen = true;
}
public TensorFlowEstimator(IHostEnvironment env, string modelFile, string[] inputs, string[] outputs, bool isFrozen = Defaults.IsFrozen )
: this(env, new TensorFlowTransform(env, modelFile, inputs, outputs, isFrozen))
{
}

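Hypothetical usage of this new surface, pointing the estimator at a SavedModel directory instead of a frozen .pb file (column names are illustrative):

var estimator = new TensorFlowEstimator(env,
    modelFile: "mnist_saved_model",   // a directory in SavedModel format
    inputs: new[] { "Placeholder" },
    outputs: new[] { "Softmax" },
    isFrozen: false);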
9 changes: 7 additions & 2 deletions src/Microsoft.ML/CSharpApi.cs
@@ -15785,9 +15785,14 @@ public sealed partial class TensorFlowScorer : Microsoft.ML.Runtime.EntryPoints.


/// <summary>
/// This is the frozen protobuf model file. Please see https://www.tensorflow.org/mobile/prepare_models for more details.
/// TensorFlow model used by the transform. Please see https://www.tensorflow.org/mobile/prepare_models for more details.
/// </summary>
public string ModelFile { get; set; }
public string Model { get; set; }

/// <summary>
/// Indicator for frozen models
/// </summary>
public bool IsFrozen { get; set; } = true;

/// <summary>
/// The names of the model inputs
37 changes: 23 additions & 14 deletions test/BaselineOutput/Common/EntryPoints/core_manifest.json
@@ -21716,14 +21716,31 @@
"ShortName": "TFTransform",
"Inputs": [
{
"Name": "ModelFile",
"Name": "Model",
"Type": "String",
"Desc": "This is the frozen protobuf model file. Please see https://www.tensorflow.org/mobile/prepare_models for more details.",
"Desc": "TensorFlow model used by the transform. Please see https://www.tensorflow.org/mobile/prepare_models for more details.",
"Required": true,
"SortOrder": 0.0,
"IsNullable": false
},
{
"Name": "IsFrozen",
"Type": "Bool",
"Desc": "Indicator for frozen models",
"Aliases": [
"model"
"frozen"
],
"Required": false,
"SortOrder": 1.0,
"IsNullable": false,
"Default": true
},
{
"Name": "Data",
"Type": "DataView",
"Desc": "Input dataset",
"Required": true,
"SortOrder": 0.0,
"SortOrder": 1.0,
"IsNullable": false
},
{
Expand All @@ -21737,15 +21754,7 @@
"inputs"
],
"Required": true,
"SortOrder": 1.0,
"IsNullable": false
},
{
"Name": "Data",
"Type": "DataView",
"Desc": "Input dataset",
"Required": true,
"SortOrder": 1.0,
"SortOrder": 2.0,
"IsNullable": false
},
{
Expand All @@ -21759,7 +21768,7 @@
"outputs"
],
"Required": true,
"SortOrder": 2.0,
"SortOrder": 3.0,
"IsNullable": false
}
],