Skip to content

Commit 57508d3

Browse files
committed
Works with non-frozen models
1 parent 6291d0d commit 57508d3

File tree

2 files changed

+94
-38
lines changed

2 files changed

+94
-38
lines changed

src/Microsoft.ML.TensorFlow/TensorflowTransform.cs

+79-23
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
using System;
66
using System.IO;
7+
using System.IO.Compression;
78
using System.Linq;
89
using Microsoft.ML.Runtime;
910
using Microsoft.ML.Runtime.CommandLine;
@@ -41,6 +42,8 @@ internal sealed class TensorFlowMapper : IRowMapper
4142
private readonly bool[] _isVectorInput;
4243
private readonly TFShape[] _tfInputShapes;
4344
private readonly TFDataType[] _tfInputTypes;
45+
private readonly bool _isFrozen;
46+
private readonly string _exportDir;
4447

4548
private readonly string _outputColName;
4649
private readonly ColumnType _outputColType;
@@ -66,7 +69,7 @@ public TensorFlowMapper(IHostEnvironment env, ISchema inputSchema, byte[] modelB
6669
_host.CheckNonEmpty(modelBytes, nameof(modelBytes));
6770
_host.CheckNonEmpty(inputColNames, nameof(inputColNames));
6871
_host.CheckNonEmpty(outputColName, nameof(outputColName));
69-
72+
_isFrozen = true;
7073
_session = LoadTFSession(modelBytes, null);
7174
_host.CheckValue(_session.Graph[outputColName], nameof(outputColName), "Output does not exist in the model");
7275
_host.Check(inputColNames.All(name => _session.Graph[name] != null), "One of the input does not exist in the model");
@@ -83,6 +86,8 @@ public TensorFlowMapper(IHostEnvironment env, ISchema inputSchema, string export
8386
_host.CheckValue(inputSchema, nameof(inputSchema));
8487
_host.CheckNonEmpty(inputColNames, nameof(inputColNames));
8588
_host.CheckNonEmpty(outputColName, nameof(outputColName));
89+
_isFrozen = false;
90+
_exportDir = exportDir;
8691

8792
_session = LoadTFSession(exportDir);
8893
_host.CheckValue(_session.Graph[outputColName], nameof(outputColName), "Output does not exist in the model");
@@ -99,41 +104,92 @@ public static TensorFlowMapper Create(IHostEnvironment env, ModelLoadContext ctx
99104
env.CheckValue(ctx, nameof(ctx));
100105
ctx.CheckAtModel(GetVersionInfo());
101106

102-
var numInputs = ctx.Reader.ReadInt32();
103-
Contracts.CheckDecode(numInputs > 0);
107+
var isFrozen = ctx.Reader.ReadInt32();
108+
if (isFrozen==1)
109+
{
110+
var numInputs = ctx.Reader.ReadInt32();
111+
Contracts.CheckDecode(numInputs > 0);
112+
113+
string[] source = new string[numInputs];
114+
for (int j = 0; j < source.Length; j++)
115+
source[j] = ctx.LoadNonEmptyString();
116+
117+
byte[] data = null;
118+
if (!ctx.TryLoadBinaryStream("TFModel", r => data = r.ReadByteArray()))
119+
throw env.ExceptDecode();
120+
121+
var outputColName = ctx.LoadNonEmptyString();
122+
123+
return new TensorFlowMapper(env, schema, data, source, outputColName);
124+
}
125+
else
126+
{
127+
var numInputs = ctx.Reader.ReadInt32();
128+
Contracts.CheckDecode(numInputs > 0);
129+
130+
string[] source = new string[numInputs];
131+
for (int j = 0; j < source.Length; j++)
132+
source[j] = ctx.LoadNonEmptyString();
133+
134+
// Load model binary
135+
byte[] tfFilesBin = null;
136+
var load = ctx.TryLoadBinaryStream("TFSavedModel", br => tfFilesBin = br.ReadByteArray());
104137

105-
string[] source = new string[numInputs];
106-
for (int j = 0; j < source.Length; j++)
107-
source[j] = ctx.LoadNonEmptyString();
138+
var tempDirName = Path.GetFullPath(Path.Combine(Path.GetTempPath(), "_MLNET_TFTransform_" + Guid.NewGuid()));
139+
var tempDir = Directory.CreateDirectory(tempDirName);
140+
var tfZipFilePath = Path.Combine(tempDir.FullName, "tf_savedmodel.zip");
108141

109-
byte[] data = null;
110-
if (!ctx.TryLoadBinaryStream("TFModel", r => data = r.ReadByteArray()))
111-
throw env.ExceptDecode();
142+
File.WriteAllBytes(tfZipFilePath, tfFilesBin);
143+
ZipFile.ExtractToDirectory(tfZipFilePath, Path.Combine(tempDir.FullName, "tf_savedmodel"));
112144

113-
var outputColName = ctx.LoadNonEmptyString();
145+
var outputColName = ctx.LoadNonEmptyString();
114146

115-
return new TensorFlowMapper(env, schema, data, source, outputColName);
147+
return new TensorFlowMapper(env, schema, Path.Combine(tempDir.FullName, "tf_savedmodel"), source, outputColName);
148+
}
116149
}
117150

118151
public void Save(ModelSaveContext ctx)
119152
{
120153
_host.AssertValue(ctx);
121154
ctx.CheckAtModel();
122155
ctx.SetVersionInfo(GetVersionInfo());
156+
ctx.Writer.Write(_isFrozen ? 1 : 0);
157+
if (_isFrozen)
158+
{
159+
var buffer = new TFBuffer();
160+
_session.Graph.ToGraphDef(buffer);
123161

124-
var buffer = new TFBuffer();
125-
_session.Graph.ToGraphDef(buffer);
126-
127-
ctx.SaveBinaryStream("TFModel", w =>
162+
ctx.SaveBinaryStream("TFModel", w =>
163+
{
164+
w.WriteByteArray(buffer.ToArray());
165+
});
166+
Contracts.AssertNonEmpty(_inputColNames);
167+
ctx.Writer.Write(_inputColNames.Length);
168+
foreach (var colName in _inputColNames)
169+
ctx.SaveNonEmptyString(colName);
170+
171+
ctx.SaveNonEmptyString(_outputColName);
172+
}
173+
else
128174
{
129-
w.WriteByteArray(buffer.ToArray());
130-
});
131-
Contracts.AssertNonEmpty(_inputColNames);
132-
ctx.Writer.Write(_inputColNames.Length);
133-
foreach (var colName in _inputColNames)
134-
ctx.SaveNonEmptyString(colName);
135-
136-
ctx.SaveNonEmptyString(_outputColName);
175+
var tempDirName = Path.GetFullPath(Path.Combine(Path.GetTempPath(), "_MLNET_TFTransform_" + Guid.NewGuid()));
176+
var tempDir = Directory.CreateDirectory(tempDirName);
177+
var tfZipFilePath = Path.Combine(tempDir.FullName, "tf_savedmodel.zip");
178+
179+
ZipFile.CreateFromDirectory(_exportDir, tfZipFilePath, CompressionLevel.Fastest, false);
180+
byte[] byteArray = File.ReadAllBytes(tfZipFilePath);
181+
ctx.SaveBinaryStream("TFSavedModel", w =>
182+
{
183+
w.WriteByteArray(byteArray);
184+
});
185+
186+
Contracts.AssertNonEmpty(_inputColNames);
187+
ctx.Writer.Write(_inputColNames.Length);
188+
foreach (var colName in _inputColNames)
189+
ctx.SaveNonEmptyString(colName);
190+
191+
ctx.SaveNonEmptyString(_outputColName);
192+
}
137193
}
138194

139195
private TFSession LoadTFSession(byte[] modelBytes, string modelArg)

test/Microsoft.ML.Tests/Scenarios/PipelineApi/TrainSaveModelAndPredict.cs

+15-15
Original file line numberDiff line numberDiff line change
@@ -102,21 +102,21 @@ public async void TensorFlowTransformTrainSaveModelAndPredict()
102102
Assert.Equal("hotdog", scoreLabels[1]);
103103
Assert.Equal("tomato", scoreLabels[2]);
104104

105-
//CifarPrediction prediction = loadedModel.Predict(new CifarData()
106-
//{
107-
// ImagePath = GetDataPath("images/banana.jpg")
108-
//});
109-
//Assert.Equal(1, prediction.PredictedLabels[0], 2);
110-
//Assert.Equal(0, prediction.PredictedLabels[1], 2);
111-
//Assert.Equal(0, prediction.PredictedLabels[2], 2);
112-
113-
//prediction = loadedModel.Predict(new CifarData()
114-
//{
115-
// ImagePath = GetDataPath("images/hotdog.jpg")
116-
//});
117-
//Assert.Equal(0, prediction.PredictedLabels[0], 2);
118-
//Assert.Equal(1, prediction.PredictedLabels[1], 2);
119-
//Assert.Equal(0, prediction.PredictedLabels[2], 2);
105+
CifarPrediction prediction = loadedModel.Predict(new CifarData()
106+
{
107+
ImagePath = GetDataPath("images/banana.jpg")
108+
});
109+
Assert.Equal(1, prediction.PredictedLabels[0], 2);
110+
Assert.Equal(0, prediction.PredictedLabels[1], 2);
111+
Assert.Equal(0, prediction.PredictedLabels[2], 2);
112+
113+
prediction = loadedModel.Predict(new CifarData()
114+
{
115+
ImagePath = GetDataPath("images/hotdog.jpg")
116+
});
117+
Assert.Equal(0, prediction.PredictedLabels[0], 2);
118+
Assert.Equal(1, prediction.PredictedLabels[1], 2);
119+
Assert.Equal(0, prediction.PredictedLabels[2], 2);
120120
}
121121
}
122122
}

0 commit comments

Comments
 (0)