Skip to content

Commit ce0c917

Browse files
authored
Fixed a tensorflow test which was marked as skipped. (#2855)
1 parent f856d09 commit ce0c917

File tree

2 files changed

+22
-7
lines changed

2 files changed

+22
-7
lines changed

src/Microsoft.ML.TensorFlow/TensorFlowModel.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ public DataViewSchema GetInputSchema()
6363
/// </format>
6464
/// </example>
6565
public TensorFlowEstimator ScoreTensorFlowModel(string outputColumnName, string inputColumnName)
66-
=> new TensorFlowEstimator(_env, new[] { outputColumnName }, new[] { inputColumnName }, ModelPath);
66+
=> new TensorFlowEstimator(_env, new[] { outputColumnName }, new[] { inputColumnName }, this);
6767

6868
/// <summary>
6969
/// Scores a dataset using a pre-traiend TensorFlow model.
@@ -78,7 +78,7 @@ public TensorFlowEstimator ScoreTensorFlowModel(string outputColumnName, string
7878
/// </format>
7979
/// </example>
8080
public TensorFlowEstimator ScoreTensorFlowModel(string[] outputColumnNames, string[] inputColumnNames)
81-
=> new TensorFlowEstimator(_env, outputColumnNames, inputColumnNames, ModelPath);
81+
=> new TensorFlowEstimator(_env, outputColumnNames, inputColumnNames, this);
8282

8383
/// <summary>
8484
/// Retrain the TensorFlow model on new data.

test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs

+20-5
Original file line numberDiff line numberDiff line change
@@ -374,17 +374,32 @@ public void TensorFlowTransformObjectDetectionTest()
374374
[Fact(Skip = "Model files are not available yet")]
375375
public void TensorFlowTransformInceptionTest()
376376
{
377-
var modelLocation = @"C:\models\TensorFlow\tensorflow_inception_graph.pb";
377+
string inputName = "input";
378+
string outputName = "softmax2_pre_activation";
379+
var modelLocation = @"inception5h\tensorflow_inception_graph.pb";
378380
var mlContext = new MLContext(seed: 1);
379381
var dataFile = GetDataPath("images/images.tsv");
380382
var imageFolder = Path.GetDirectoryName(dataFile);
381-
var data = mlContext.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile));
383+
var reader = mlContext.Data.CreateTextLoader(
384+
columns: new[]
385+
{
386+
new TextLoader.Column("ImagePath", DataKind.String , 0),
387+
new TextLoader.Column("Name", DataKind.String, 1)
388+
389+
},
390+
hasHeader: false,
391+
allowSparse: false
392+
);
393+
394+
var data = reader.Load(new MultiFileSource(dataFile));
382395
var images = mlContext.Transforms.LoadImages(imageFolder, ("ImageReal", "ImagePath")).Fit(data).Transform(data);
383396
var cropped = mlContext.Transforms.ResizeImages("ImageCropped", 224, 224, "ImageReal").Fit(images).Transform(images);
384-
var pixels = mlContext.Transforms.ExtractPixels("input", "ImageCropped").Fit(cropped).Transform(cropped);
385-
var tf = mlContext.Model.LoadTensorFlowModel(modelLocation).ScoreTensorFlowModel("softmax2_pre_activation", "input").Fit(pixels).Transform(pixels);
397+
var pixels = mlContext.Transforms.ExtractPixels(inputName, "ImageCropped").Fit(cropped).Transform(cropped);
398+
var tf = mlContext.Model.LoadTensorFlowModel(modelLocation).ScoreTensorFlowModel(outputName, inputName).Fit(pixels).Transform(pixels);
386399

387-
using (var curs = tf.GetRowCursor(tf.Schema["input"], tf.Schema["softmax2_pre_activation"]))
400+
tf.Schema.TryGetColumnIndex(inputName, out int input);
401+
tf.Schema.TryGetColumnIndex(outputName, out int b);
402+
using (var curs = tf.GetRowCursor(tf.Schema[inputName], tf.Schema[outputName]))
388403
{
389404
var get = curs.GetGetter<VBuffer<float>>(tf.Schema["softmax2_pre_activation"]);
390405
var getInput = curs.GetGetter<VBuffer<float>>(tf.Schema["input"]);

0 commit comments

Comments
 (0)