Skip to content

Commit f78af89

Browse files
authored
Internalization of TensorFlowUtils.cs and refactored TensorFlowCatalog. (#2672)
1 parent 2b417bb commit f78af89

File tree

13 files changed

+246
-271
lines changed

13 files changed

+246
-271
lines changed

docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlow/ImageClassification.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ public static void Example()
2020
var idv = mlContext.Data.LoadFromEnumerable(data);
2121

2222
// Create a ML pipeline.
23-
var pipeline = mlContext.Transforms.ScoreTensorFlowModel(
24-
modelLocation,
23+
var pipeline = mlContext.Model.LoadTensorFlowModel(modelLocation).ScoreTensorFlowModel(
2524
new[] { nameof(OutputScores.output) },
2625
new[] { nameof(TensorData.input) });
2726

docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlow/TextClassification.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ public static void Example()
4545
// Load the TensorFlow model once.
4646
// - Use it for quering the schema for input and output in the model
4747
// - Use it for prediction in the pipeline.
48-
var modelInfo = TensorFlowUtils.LoadTensorFlowModel(mlContext, modelLocation);
49-
var schema = modelInfo.GetModelSchema();
48+
var tensorFlowModel = mlContext.Model.LoadTensorFlowModel(modelLocation);
49+
var schema = tensorFlowModel.GetModelSchema();
5050
var featuresType = (VectorType)schema["Features"].Type;
5151
Console.WriteLine("Name: {0}, Type: {1}, Shape: (-1, {2})", "Features", featuresType.ItemType.RawType, featuresType.Dimensions[0]);
5252
var predictionType = (VectorType)schema["Prediction/Softmax"].Type;
@@ -72,7 +72,7 @@ public static void Example()
7272
var engine = mlContext.Transforms.Text.TokenizeWords("TokenizedWords", "Sentiment_Text")
7373
.Append(mlContext.Transforms.Conversion.ValueMap(lookupMap, "Words", "Ids", new ColumnOptions[] { ("VariableLenghtFeatures", "TokenizedWords") }))
7474
.Append(mlContext.Transforms.CustomMapping(ResizeFeaturesAction, "Resize"))
75-
.Append(mlContext.Transforms.ScoreTensorFlowModel(modelInfo, new[] { "Prediction/Softmax" }, new[] { "Features" }))
75+
.Append(tensorFlowModel.ScoreTensorFlowModel(new[] { "Prediction/Softmax" }, new[] { "Features" }))
7676
.Append(mlContext.Transforms.CopyColumns(("Prediction", "Prediction/Softmax")))
7777
.Fit(dataView)
7878
.CreatePredictionEngine<IMDBSentiment, OutputScores>(mlContext);

src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs

+34-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
// See the LICENSE file in the project root for more information.
44

55
using System;
6-
using Microsoft.ML.Transforms.TensorFlow;
6+
using System.Collections.Generic;
7+
using System.Linq;
8+
using Microsoft.Data.DataView;
9+
using Microsoft.ML.Data;
710

811
namespace Microsoft.ML.DnnAnalyzer
912
{
@@ -17,11 +20,40 @@ public static void Main(string[] args)
1720
return;
1821
}
1922

20-
foreach (var (name, opType, type, inputs) in TensorFlowUtils.GetModelNodes(new MLContext(), args[0]))
23+
foreach (var (name, opType, type, inputs) in GetModelNodes(args[0]))
2124
{
2225
var inputsString = inputs.Length == 0 ? "" : $", input nodes: {string.Join(", ", inputs)}";
2326
Console.WriteLine($"Graph node: '{name}', operation type: '{opType}', output type: '{type}'{inputsString}");
2427
}
2528
}
29+
30+
private static IEnumerable<(string, string, DataViewType, string[])> GetModelNodes(string modelPath)
31+
{
32+
var mlContext = new MLContext();
33+
var tensorFlowModel = mlContext.Model.LoadTensorFlowModel(modelPath);
34+
var schema = tensorFlowModel.GetModelSchema();
35+
36+
for (int i = 0; i < schema.Count; i++)
37+
{
38+
var name = schema[i].Name;
39+
var type = schema[i].Type;
40+
41+
var metadataType = schema[i].Annotations.Schema.GetColumnOrNull("TensorflowOperatorType")?.Type;
42+
ReadOnlyMemory<char> opType = default;
43+
schema[i].Annotations.GetValue("TensorflowOperatorType", ref opType);
44+
metadataType = schema[i].Annotations.Schema.GetColumnOrNull("TensorflowUpstreamOperators")?.Type;
45+
VBuffer <ReadOnlyMemory<char>> inputOps = default;
46+
if (metadataType != null)
47+
{
48+
schema[i].Annotations.GetValue("TensorflowUpstreamOperators", ref inputOps);
49+
}
50+
51+
string[] inputOpsResult = inputOps.DenseValues()
52+
.Select(input => input.ToString())
53+
.ToArray();
54+
55+
yield return (name, opType.ToString(), type, inputOpsResult);
56+
}
57+
}
2658
}
2759
}

src/Microsoft.ML.TensorFlow.StaticPipe/TensorFlowStaticExtensions.cs

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ public OutColumn(Vector<float> input, string modelFile)
1919
Input = input;
2020
}
2121

22-
public OutColumn(Vector<float> input, TensorFlowModelInfo tensorFlowModel)
22+
public OutColumn(Vector<float> input, TensorFlowModel tensorFlowModel)
2323
: base(new Reconciler(tensorFlowModel), input)
2424
{
2525
Input = input;
@@ -29,7 +29,7 @@ public OutColumn(Vector<float> input, TensorFlowModelInfo tensorFlowModel)
2929
private sealed class Reconciler : EstimatorReconciler
3030
{
3131
private readonly string _modelFile;
32-
private readonly TensorFlowModelInfo _tensorFlowModel;
32+
private readonly TensorFlowModel _tensorFlowModel;
3333

3434
public Reconciler(string modelFile)
3535
{
@@ -38,7 +38,7 @@ public Reconciler(string modelFile)
3838
_tensorFlowModel = null;
3939
}
4040

41-
public Reconciler(TensorFlowModelInfo tensorFlowModel)
41+
public Reconciler(TensorFlowModel tensorFlowModel)
4242
{
4343
Contracts.CheckValue(tensorFlowModel, nameof(tensorFlowModel));
4444

@@ -80,7 +80,7 @@ public static Vector<float> ApplyTensorFlowGraph(this Vector<float> input, strin
8080
/// Run a TensorFlow model provided through <paramref name="tensorFlowModel"/> on the input column and extract one output column.
8181
/// The inputs and outputs are matched to TensorFlow graph nodes by name.
8282
/// </summary>
83-
public static Vector<float> ApplyTensorFlowGraph(this Vector<float> input, TensorFlowModelInfo tensorFlowModel)
83+
public static Vector<float> ApplyTensorFlowGraph(this Vector<float> input, TensorFlowModel tensorFlowModel)
8484
{
8585
Contracts.CheckValue(input, nameof(input));
8686
Contracts.CheckValue(tensorFlowModel, nameof(tensorFlowModel));

src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs

+5-43
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,12 @@ public static class TensorFlowUtils
2121
/// Key to access operator's type (a string) in <see cref="DataViewSchema.Column.Annotations"/>.
2222
/// Its value describes the Tensorflow operator that produces this <see cref="DataViewSchema.Column"/>.
2323
/// </summary>
24-
public const string TensorflowOperatorTypeKind = "TensorflowOperatorType";
24+
internal const string TensorflowOperatorTypeKind = "TensorflowOperatorType";
2525
/// <summary>
2626
/// Key to access upstream operators' names (a string array) in <see cref="DataViewSchema.Column.Annotations"/>.
2727
/// Its value states operators that the associated <see cref="DataViewSchema.Column"/>'s generator depends on.
2828
/// </summary>
29-
public const string TensorflowUpstreamOperatorsKind = "TensorflowUpstreamOperators";
29+
internal const string TensorflowUpstreamOperatorsKind = "TensorflowUpstreamOperators";
3030

3131
internal static DataViewSchema GetModelSchema(IExceptionContext ectx, TFGraph graph, string opType = null)
3232
{
@@ -94,50 +94,12 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, TFGraph gr
9494
/// </summary>
9595
/// <param name="env">The environment to use.</param>
9696
/// <param name="modelPath">Model to load.</param>
97-
public static DataViewSchema GetModelSchema(IHostEnvironment env, string modelPath)
97+
internal static DataViewSchema GetModelSchema(IHostEnvironment env, string modelPath)
9898
{
9999
var model = LoadTensorFlowModel(env, modelPath);
100100
return GetModelSchema(env, model.Session.Graph);
101101
}
102102

103-
/// <summary>
104-
/// This is a convenience method for iterating over the nodes of a TensorFlow model graph. It
105-
/// iterates over the columns of the <see cref="DataViewSchema"/> returned by <see cref="GetModelSchema(IHostEnvironment, string)"/>,
106-
/// and for each one it returns a tuple containing the name, operation type, column type and an array of input node names.
107-
/// This method is convenient for filtering nodes based on certain criteria, for example, by the operation type.
108-
/// </summary>
109-
/// <param name="env">The environment to use.</param>
110-
/// <param name="modelPath">Model to load.</param>
111-
/// <returns></returns>
112-
public static IEnumerable<(string, string, DataViewType, string[])> GetModelNodes(IHostEnvironment env, string modelPath)
113-
{
114-
var schema = GetModelSchema(env, modelPath);
115-
116-
for (int i = 0; i < schema.Count; i++)
117-
{
118-
var name = schema[i].Name;
119-
var type = schema[i].Type;
120-
121-
var metadataType = schema[i].Annotations.Schema.GetColumnOrNull(TensorflowOperatorTypeKind)?.Type;
122-
Contracts.Assert(metadataType != null && metadataType is TextDataViewType);
123-
ReadOnlyMemory<char> opType = default;
124-
schema[i].Annotations.GetValue(TensorflowOperatorTypeKind, ref opType);
125-
metadataType = schema[i].Annotations.Schema.GetColumnOrNull(TensorflowUpstreamOperatorsKind)?.Type;
126-
VBuffer<ReadOnlyMemory<char>> inputOps = default;
127-
if (metadataType != null)
128-
{
129-
Contracts.Assert(metadataType.IsKnownSizeVector() && metadataType.GetItemType() is TextDataViewType);
130-
schema[i].Annotations.GetValue(TensorflowUpstreamOperatorsKind, ref inputOps);
131-
}
132-
133-
string[] inputOpsResult = inputOps.DenseValues()
134-
.Select(input => input.ToString())
135-
.ToArray();
136-
137-
yield return (name, opType.ToString(), type, inputOpsResult);
138-
}
139-
}
140-
141103
internal static PrimitiveDataViewType Tf2MlNetType(TFDataType type)
142104
{
143105
var mlNetType = Tf2MlNetTypeOrNull(type);
@@ -338,10 +300,10 @@ private static void CreateTempDirectoryWithAcl(string folder, string identity)
338300
/// <param name="env">The environment to use.</param>
339301
/// <param name="modelPath">The model to load.</param>
340302
/// <returns></returns>
341-
public static TensorFlowModelInfo LoadTensorFlowModel(IHostEnvironment env, string modelPath)
303+
internal static TensorFlowModel LoadTensorFlowModel(IHostEnvironment env, string modelPath)
342304
{
343305
var session = GetSession(env, modelPath);
344-
return new TensorFlowModelInfo(env, session, modelPath);
306+
return new TensorFlowModel(env, session, modelPath);
345307
}
346308

347309
internal static TFSession GetSession(IHostEnvironment env, string modelPath)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
// See the LICENSE file in the project root for more information.
4+
5+
using Microsoft.Data.DataView;
6+
using Microsoft.ML.Transforms.TensorFlow;
7+
8+
namespace Microsoft.ML.Transforms
9+
{
10+
/// <summary>
11+
/// This class holds the information related to TensorFlow model and session.
12+
/// It provides some convenient methods to query model schema as well as
13+
/// creation of <see cref="TensorFlowEstimator"/> object.
14+
/// </summary>
15+
public sealed class TensorFlowModel
16+
{
17+
internal TFSession Session { get; }
18+
internal string ModelPath { get; }
19+
20+
private readonly IHostEnvironment _env;
21+
22+
/// <summary>
23+
/// Instantiates <see cref="TensorFlowModel"/>.
24+
/// </summary>
25+
/// <param name="env">An <see cref="IHostEnvironment"/> object.</param>
26+
/// <param name="session">TensorFlow session object.</param>
27+
/// <param name="modelLocation">Location of the model from where <paramref name="session"/> was loaded.</param>
28+
internal TensorFlowModel(IHostEnvironment env, TFSession session, string modelLocation)
29+
{
30+
Session = session;
31+
ModelPath = modelLocation;
32+
_env = env;
33+
}
34+
35+
/// <summary>
36+
/// Get <see cref="DataViewSchema"/> for complete model. Every node in the TensorFlow model will be included in the <see cref="DataViewSchema"/> object.
37+
/// </summary>
38+
public DataViewSchema GetModelSchema()
39+
{
40+
return TensorFlowUtils.GetModelSchema(_env, Session.Graph);
41+
}
42+
43+
/// <summary>
44+
/// Get <see cref="DataViewSchema"/> for only those nodes which are marked "Placeholder" in the TensorFlow model.
45+
/// This method is convenient for exploring the model input(s) in case TensorFlow graph is very large.
46+
/// </summary>
47+
public DataViewSchema GetInputSchema()
48+
{
49+
return TensorFlowUtils.GetModelSchema(_env, Session.Graph, "Placeholder");
50+
}
51+
52+
/// <summary>
53+
/// Scores a dataset using a pre-traiend <a href="https://www.tensorflow.org/">TensorFlow</a> model.
54+
/// </summary>
55+
/// <param name="inputColumnName"> The name of the model input.</param>
56+
/// <param name="outputColumnName">The name of the requested model output.</param>
57+
/// <example>
58+
/// <format type="text/markdown">
59+
/// <![CDATA[
60+
/// [!code-csharp[ScoreTensorFlowModel](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlowTransform.cs)]
61+
/// ]]>
62+
/// </format>
63+
/// </example>
64+
public TensorFlowEstimator ScoreTensorFlowModel(string outputColumnName, string inputColumnName)
65+
=> new TensorFlowEstimator(_env, new[] { outputColumnName }, new[] { inputColumnName }, ModelPath);
66+
67+
/// <summary>
68+
/// Scores a dataset using a pre-traiend TensorFlow model.
69+
/// </summary>
70+
/// <param name="inputColumnNames"> The names of the model inputs.</param>
71+
/// <param name="outputColumnNames">The names of the requested model outputs.</param>
72+
/// <example>
73+
/// <format type="text/markdown">
74+
/// <![CDATA[
75+
/// [!code-csharp[ScoreTensorFlowModel](~/../docs/samples/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlow/ImageClassification.cs)]
76+
/// ]]>
77+
/// </format>
78+
/// </example>
79+
public TensorFlowEstimator ScoreTensorFlowModel(string[] outputColumnNames, string[] inputColumnNames)
80+
=> new TensorFlowEstimator(_env, outputColumnNames, inputColumnNames, ModelPath);
81+
82+
/// <summary>
83+
/// Retrain the TensorFlow model on new data.
84+
/// The model is not loaded again instead the information contained in <see cref="TensorFlowModel"/> class is reused
85+
/// (c.f. <see cref="TensorFlowModel.ModelPath"/> and <see cref="TensorFlowModel.Session"/>).
86+
/// </summary>
87+
/// <param name="inputColumnNames"> The names of the model inputs.</param>
88+
/// <param name="outputColumnNames">The names of the requested model outputs.</param>
89+
/// <param name="labelColumnName">Name of the label column.</param>
90+
/// <param name="tensorFlowLabel">Name of the node in TensorFlow graph that is used as label during training in TensorFlow.
91+
/// The value of <paramref name="labelColumnName"/> from <see cref="IDataView"/> is fed to this node.</param>
92+
/// <param name="optimizationOperation">The name of the optimization operation in the TensorFlow graph.</param>
93+
/// <param name="epoch">Number of training iterations.</param>
94+
/// <param name="batchSize">Number of samples to use for mini-batch training.</param>
95+
/// <param name="lossOperation">The name of the operation in the TensorFlow graph to compute training loss (Optional).</param>
96+
/// <param name="metricOperation">The name of the operation in the TensorFlow graph to compute performance metric during training (Optional).</param>
97+
/// <param name="learningRateOperation">The name of the operation in the TensorFlow graph which sets optimizer learning rate (Optional).</param>
98+
/// <param name="learningRate">Learning rate to use during optimization (Optional).</param>
99+
/// <remarks>
100+
/// The support for retraining is experimental.
101+
/// </remarks>
102+
public TensorFlowEstimator RetrainTensorFlowModel(
103+
string[] outputColumnNames,
104+
string[] inputColumnNames,
105+
string labelColumnName,
106+
string tensorFlowLabel,
107+
string optimizationOperation,
108+
int epoch = 10,
109+
int batchSize = 20,
110+
string lossOperation= null,
111+
string metricOperation = null,
112+
string learningRateOperation = null,
113+
float learningRate = 0.01f)
114+
{
115+
var options = new TensorFlowEstimator.Options()
116+
{
117+
ModelLocation = ModelPath,
118+
InputColumns = inputColumnNames,
119+
OutputColumns = outputColumnNames,
120+
LabelColumn = labelColumnName,
121+
TensorFlowLabel = tensorFlowLabel,
122+
OptimizationOperation = optimizationOperation,
123+
LossOperation = lossOperation,
124+
MetricOperation = metricOperation,
125+
Epoch = epoch,
126+
LearningRateOperation = learningRateOperation,
127+
LearningRate = learningRate,
128+
BatchSize = batchSize,
129+
ReTrain = true
130+
};
131+
return new TensorFlowEstimator(_env, options, this);
132+
}
133+
}
134+
}

src/Microsoft.ML.TensorFlow/TensorFlowModelInfo.cs

-60
This file was deleted.

0 commit comments

Comments
 (0)