Skip to content

Commit c4cd6b7

Browse files
author
Pete Luferenko
committed
Merge remote-tracking branch 'upstream/master' into feature/normalizers
2 parents 85bee22 + a3b67d3 commit c4cd6b7

File tree

3 files changed

+94
-19
lines changed

3 files changed

+94
-19
lines changed

src/Microsoft.ML.TensorFlow/doc.xml

+6-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
The TensorflowTransform extracts the specified output from the operation computed on the graph (given the input(s)) using a pre-trained <a href="https://www.tensorflow.org">Tensorflow</a> model.
1111
The transform takes as input the Tensorflow model together with the names of the inputs to the model and name of the operation for which output values will be extracted from the model.
1212

13+
This transform requires the <a href="https://dotnet.myget.org/feed/dotnet-core/package/nuget/Microsoft.ML.TensorFlow/0.5.0-preview-26830-5">Microsoft.ML.TensorFlow</a> nuget to be installed.
14+
1315
The TensorflowTransform has following assumptions regarding the input, output and processing of data.
1416
<list type="number">
1517
<item>
@@ -23,6 +25,9 @@
2325
Upon success, the transform will introduce a new column in <see cref="IDataView"/> based on the name of the output column specified.
2426
</item>
2527
</list>
28+
29+
The inputs and outputs of a TensorFlow model can be obtained using the <a href="https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/graph_transforms/README.md#inspecting-graphs"><code>summarize_graph</code> tool</a>.
30+
2631
</remarks>
2732
</member>
2833
<example name="TensorflowTransform">
@@ -71,4 +76,4 @@
7176
</example>
7277

7378
</members>
74-
</doc>
79+
</doc>

test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs

+36-17
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
using Microsoft.ML.Runtime.Data;
88
using Microsoft.ML.Runtime.ImageAnalytics;
99
using Microsoft.ML.Runtime.LightGBM;
10+
using Microsoft.ML.Trainers;
1011
using Microsoft.ML.Transforms;
1112
using System.Collections.Generic;
1213
using System.IO;
@@ -16,7 +17,7 @@ namespace Microsoft.ML.Scenarios
1617
{
1718
public partial class ScenariosTests
1819
{
19-
[Fact(Skip = "Disabled due to this bug https://github.com/dotnet/machinelearning/issues/770")]
20+
[Fact]
2021
public void TensorFlowTransformCifarLearningPipelineTest()
2122
{
2223
var imageHeight = 32;
@@ -52,23 +53,35 @@ public void TensorFlowTransformCifarLearningPipelineTest()
5253
OutputColumn = "Output"
5354
});
5455

55-
using (var environment = new TlcEnvironment())
56+
pipeline.Add(new ColumnConcatenator(outputColumn: "Features", "Output"));
57+
pipeline.Add(new TextToKeyConverter("Label"));
58+
pipeline.Add(new StochasticDualCoordinateAscentClassifier());
59+
60+
var model = pipeline.Train<CifarData, CifarPrediction>();
61+
string[] scoreLabels;
62+
model.TryGetScoreLabelNames(out scoreLabels);
63+
64+
Assert.NotNull(scoreLabels);
65+
Assert.Equal(3, scoreLabels.Length);
66+
Assert.Equal("banana", scoreLabels[0]);
67+
Assert.Equal("hotdog", scoreLabels[1]);
68+
Assert.Equal("tomato", scoreLabels[2]);
69+
70+
CifarPrediction prediction = model.Predict(new CifarData()
5671
{
57-
IDataView trans = pipeline.Execute(environment);
58-
Assert.NotNull(trans);
72+
ImagePath = GetDataPath("images/banana.jpg")
73+
});
74+
Assert.Equal(1, prediction.PredictedLabels[0], 2);
75+
Assert.Equal(0, prediction.PredictedLabels[1], 2);
76+
Assert.Equal(0, prediction.PredictedLabels[2], 2);
5977

60-
trans.Schema.TryGetColumnIndex("Output", out int output);
61-
using (var cursor = trans.GetRowCursor(col => col == output))
62-
{
63-
var buffer = default(VBuffer<float>);
64-
var getter = cursor.GetGetter<VBuffer<float>>(output);
65-
while (cursor.MoveNext())
66-
{
67-
getter(ref buffer);
68-
Assert.Equal(10, buffer.Length);
69-
}
70-
}
71-
}
78+
prediction = model.Predict(new CifarData()
79+
{
80+
ImagePath = GetDataPath("images/hotdog.jpg")
81+
});
82+
Assert.Equal(0, prediction.PredictedLabels[0], 2);
83+
Assert.Equal(1, prediction.PredictedLabels[1], 2);
84+
Assert.Equal(0, prediction.PredictedLabels[2], 2);
7285
}
7386
}
7487

@@ -78,6 +91,12 @@ public class CifarData
7891
public string ImagePath;
7992

8093
[Column("1")]
81-
public string Name;
94+
public string Label;
95+
}
96+
97+
public class CifarPrediction
98+
{
99+
[ColumnName("Score")]
100+
public float[] PredictedLabels;
82101
}
83102
}

test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs

+52-1
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ public void TensorFlowTransformMNISTConvTest()
119119
var metrics = Evaluate(env, testDataScorer);
120120

121121
Assert.Equal(0.99, metrics.AccuracyMicro, 2);
122-
Assert.Equal(0.99, metrics.AccuracyMicro, 2);
122+
Assert.Equal(1.0, metrics.AccuracyMacro, 2);
123123

124124
// Create prediction engine and test predictions
125125
var model = env.CreatePredictionEngine<MNISTData, MNISTPrediction>(testDataScorer);
@@ -215,12 +215,63 @@ public void TensorFlowTransformCifar()
215215
{
216216
var buffer = default(VBuffer<float>);
217217
var getter = cursor.GetGetter<VBuffer<float>>(output);
218+
var numRows = 0;
218219
while (cursor.MoveNext())
219220
{
220221
getter(ref buffer);
221222
Assert.Equal(10, buffer.Length);
223+
numRows += 1;
224+
}
225+
Assert.Equal(3, numRows);
226+
}
227+
}
228+
}
229+
230+
[Fact]
231+
public void TensorFlowTransformCifarInvalidShape()
232+
{
233+
var model_location = "cifar_model/frozen_model.pb";
234+
235+
using (var env = new TlcEnvironment())
236+
{
237+
var imageHeight = 28;
238+
var imageWidth = 28;
239+
var dataFile = GetDataPath("images/images.tsv");
240+
var imageFolder = Path.GetDirectoryName(dataFile);
241+
var data = env.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
242+
243+
var images = ImageLoaderTransform.Create(env, new ImageLoaderTransform.Arguments()
244+
{
245+
Column = new ImageLoaderTransform.Column[1]
246+
{
247+
new ImageLoaderTransform.Column() { Source= "ImagePath", Name="ImageReal" }
248+
},
249+
ImageFolder = imageFolder
250+
}, data);
251+
var cropped = ImageResizerTransform.Create(env, new ImageResizerTransform.Arguments()
252+
{
253+
Column = new ImageResizerTransform.Column[1]{
254+
new ImageResizerTransform.Column() { Source = "ImageReal", Name= "ImageCropped", ImageHeight =imageHeight, ImageWidth = imageWidth, Resizing = ImageResizerTransform.ResizingKind.IsoCrop}
255+
}
256+
}, images);
257+
258+
var pixels = ImagePixelExtractorTransform.Create(env, new ImagePixelExtractorTransform.Arguments()
259+
{
260+
Column = new ImagePixelExtractorTransform.Column[1]{
261+
new ImagePixelExtractorTransform.Column() { Source= "ImageCropped", Name = "Input", UseAlpha=false, InterleaveArgb=true}
222262
}
263+
}, cropped);
264+
265+
var thrown = false;
266+
try
267+
{
268+
IDataView trans = TensorFlowTransform.Create(env, pixels, model_location, "Output", "Input");
269+
}
270+
catch
271+
{
272+
thrown = true;
223273
}
274+
Assert.True(thrown);
224275
}
225276
}
226277
}

0 commit comments

Comments
 (0)