Internalization of TensorFlowUtils.cs and refactored TensorFlowCatalog. #2672


Merged (14 commits, Mar 1, 2019)
Changes from 2 commits
@@ -20,7 +20,7 @@ public static void Example()
var idv = mlContext.Data.ReadFromEnumerable(data);

// Create a ML pipeline.
var pipeline = mlContext.Transforms.ScoreTensorFlowModel(
var pipeline = mlContext.Transforms.TensorFlow.ScoreTensorFlowModel(
modelLocation,
new[] { nameof(OutputScores.output) },
new[] { nameof(TensorData.input) });
@@ -45,8 +45,8 @@ public static void Example()
// Load the TensorFlow model once.
// - Use it for querying the schema for input and output in the model
// - Use it for prediction in the pipeline.
var modelInfo = TensorFlowUtils.LoadTensorFlowModel(mlContext, modelLocation);
var schema = modelInfo.GetModelSchema();
var modelInfo = mlContext.Transforms.TensorFlow.LoadTensorFlowModel(modelLocation);
var schema = mlContext.Transforms.TensorFlow.GetModelSchema(modelInfo);
var featuresType = (VectorType)schema["Features"].Type;
@yaeldekel, Feb 22, 2019

Features (start = 51, length = 8)

Can we add a sample that uses modelInfo.GetInputSchema() to find out what the name of the input node is?

#Resolved

@zeahmed (Contributor, Author), Feb 23, 2019

I see it's being used at a couple of places in the tests, e.g.

#Resolved
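
A minimal sketch of what such a sample could look like against the API as it stands in this revision of the PR: the modelInfo.GetInputSchema() instance method is removed later in this diff, so the sketch uses the GetInputSchema catalog extension added in TensorflowCatalog.cs below; the model path is a placeholder.

// Sketch only: assumes the catalog-based API from this revision of the PR.
using System;
using Microsoft.ML;

public static class InspectTensorFlowInputs
{
    public static void PrintInputNodes(string modelLocation)
    {
        var mlContext = new MLContext();

        // Load the model once; the returned object can also be passed to ScoreTensorFlowModel.
        var modelInfo = mlContext.Transforms.TensorFlow.LoadTensorFlowModel(modelLocation);

        // GetInputSchema returns only the "Placeholder" nodes, i.e. the model inputs;
        // each column name is the name of an input node.
        var inputSchema = mlContext.Transforms.TensorFlow.GetInputSchema(modelInfo);
        for (int i = 0; i < inputSchema.Count; i++)
            Console.WriteLine($"Input node: {inputSchema[i].Name}, type: {inputSchema[i].Type}");
    }
}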

Console.WriteLine("Name: {0}, Type: {1}, Shape: (-1, {2})", "Features", featuresType.ItemType.RawType, featuresType.Dimensions[0]);
var predictionType = (VectorType)schema["Prediction/Softmax"].Type;
@@ -72,7 +72,7 @@ public static void Example()
var engine = mlContext.Transforms.Text.TokenizeWords("TokenizedWords", "Sentiment_Text")
.Append(mlContext.Transforms.Conversion.ValueMap(lookupMap, "Words", "Ids", new[] { ("VariableLenghtFeatures", "TokenizedWords") }))
.Append(mlContext.Transforms.CustomMapping(ResizeFeaturesAction, "Resize"))
.Append(mlContext.Transforms.ScoreTensorFlowModel(modelInfo, new[] { "Prediction/Softmax" }, new[] { "Features" }))
.Append(mlContext.Transforms.TensorFlow.ScoreTensorFlowModel(modelInfo, new[] { "Prediction/Softmax" }, new[] { "Features" }))
.Append(mlContext.Transforms.CopyColumns(("Prediction", "Prediction/Softmax")))
.Fit(dataView)
.CreatePredictionEngine<IMDBSentiment, OutputScores>(mlContext);
16 changes: 16 additions & 0 deletions src/Microsoft.ML.Data/Transforms/TransformsCatalog.cs
@@ -37,6 +37,11 @@ public sealed class TransformsCatalog
/// </summary>
public FeatureSelectionTransforms FeatureSelection { get; }

/// <summary>
/// List of operations for using a TensorFlow model.
/// </summary>
public TensorFlowTransforms TensorFlow { get; }
@TomFinley (Contributor), Feb 21, 2019

TensorFlowTransforms (start = 15, length = 20)

Please do not do this. Otherwise we will have an empty property unless someone imports the nuget, which is confusing and undesirable.

Please follow instead the pattern that we see in image processing. You'll note that we do not have an empty image processing nuget. Rather they are added to this catalog. Similar with ONNX scoring. You'll note that these both have extensions on this TransformsCatalog catalog. What we don't have are empty properties "Images" and "ONNX" littering our central object that users interact with to instantiate components.

This is defensible since we can take someone directly importing a nuget as a strong signal that they want to actually use those transforms. #Resolved
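
For reference, the pattern being requested is essentially the pre-change shape still visible in the TensorflowCatalog.cs diff below: extension methods that attach directly to TransformsCatalog and only light up once the TensorFlow package is referenced. A minimal library-side sketch, using the options-based overload whose body appears in that diff:

// Library-side sketch only: mirrors the pre-change TensorFlow(...) signature visible in
// the TensorflowCatalog.cs diff below, i.e. an extension method on TransformsCatalog
// itself rather than on a nested sub-catalog property.
using Microsoft.ML.Data;
using Microsoft.ML.Transforms;

namespace Microsoft.ML
{
    public static class TensorflowCatalog
    {
        // Appears on mlContext.Transforms only when the TensorFlow NuGet is referenced.
        public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog,
            TensorFlowEstimator.Options options)
            => new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), options);
    }
}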


internal TransformsCatalog(IHostEnvironment env)
{
Contracts.AssertValue(env);
@@ -47,6 +52,7 @@ internal TransformsCatalog(IHostEnvironment env)
Text = new TextTransforms(this);
Projection = new ProjectionTransforms(this);
FeatureSelection = new FeatureSelectionTransforms(this);
TensorFlow = new TensorFlowTransforms(this);
}

public abstract class SubCatalogBase
@@ -109,5 +115,15 @@ internal FeatureSelectionTransforms(TransformsCatalog owner) : base(owner)
{
}
}

/// <summary>
/// The catalog of TensorFlow operations.
/// </summary>
public sealed class TensorFlowTransforms : SubCatalogBase
{
internal TensorFlowTransforms(TransformsCatalog owner) : base(owner)
{
}
}
}
}
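
From user code, the new sub-catalog surfaces as mlContext.Transforms.TensorFlow, as in the updated samples earlier in this PR. A small sketch of that call shape (the model path and node names are placeholders):

// Sketch: the call shape enabled by the new TensorFlow sub-catalog (mirrors the updated
// sample above; the model path and node names are placeholders).
using Microsoft.ML;

public static class TensorFlowSubCatalogUsage
{
    public static void BuildPipeline(string modelLocation)
    {
        var mlContext = new MLContext();
        var pipeline = mlContext.Transforms.TensorFlow.ScoreTensorFlowModel(
            modelLocation,
            new[] { "output" },   // output node name: placeholder
            new[] { "input" });   // input node name: placeholder
    }
}
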
@@ -3,7 +3,10 @@
// See the LICENSE file in the project root for more information.

using System;
using Microsoft.ML.Transforms.TensorFlow;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Data.DataView;
using Microsoft.ML.Data;

namespace Microsoft.ML.DnnAnalyzer
{
@@ -17,11 +20,40 @@ public static void Main(string[] args)
return;
}

foreach (var (name, opType, type, inputs) in TensorFlowUtils.GetModelNodes(new MLContext(), args[0]))
foreach (var (name, opType, type, inputs) in GetModelNodes(args[0]))
{
var inputsString = inputs.Length == 0 ? "" : $", input nodes: {string.Join(", ", inputs)}";
Console.WriteLine($"Graph node: '{name}', operation type: '{opType}', output type: '{type}'{inputsString}");
}
}

private static IEnumerable<(string, string, DataViewType, string[])> GetModelNodes(string modelPath)
{
var mlContext = new MLContext();
var tensorFlowModel = mlContext.Transforms.TensorFlow.LoadTensorFlowModel(modelPath);
var schema = mlContext.Transforms.TensorFlow.GetModelSchema(tensorFlowModel);

for (int i = 0; i < schema.Count; i++)
{
var name = schema[i].Name;
var type = schema[i].Type;

var metadataType = schema[i].Metadata.Schema.GetColumnOrNull("TensorflowOperatorType")?.Type;
ReadOnlyMemory<char> opType = default;
schema[i].Metadata.GetValue("TensorflowOperatorType", ref opType);
metadataType = schema[i].Metadata.Schema.GetColumnOrNull("TensorflowUpstreamOperators")?.Type;
VBuffer<ReadOnlyMemory<char>> inputOps = default;
if (metadataType != null)
{
schema[i].Metadata.GetValue("TensorflowUpstreamOperators", ref inputOps);
}

string[] inputOpsResult = inputOps.DenseValues()
.Select(input => input.ToString())
.ToArray();

yield return (name, opType.ToString(), type, inputOpsResult);
}
}
}
}
48 changes: 5 additions & 43 deletions src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs
@@ -21,12 +21,12 @@ public static class TensorFlowUtils
/// Key to access operator's type (a string) in <see cref="DataViewSchema.Column.Metadata"/>.
/// Its value describes the Tensorflow operator that produces this <see cref="DataViewSchema.Column"/>.
/// </summary>
public const string TensorflowOperatorTypeKind = "TensorflowOperatorType";
internal const string TensorflowOperatorTypeKind = "TensorflowOperatorType";
/// <summary>
/// Key to access upstream operators' names (a string array) in <see cref="DataViewSchema.Column.Metadata"/>.
/// Its value states operators that the associated <see cref="DataViewSchema.Column"/>'s generator depends on.
/// </summary>
public const string TensorflowUpstreamOperatorsKind = "TensorflowUpstreamOperators";
internal const string TensorflowUpstreamOperatorsKind = "TensorflowUpstreamOperators";

internal static DataViewSchema GetModelSchema(IExceptionContext ectx, TFGraph graph, string opType = null)
{
@@ -94,50 +94,12 @@ internal static DataViewSchema GetModelSchema(IExceptionContext ectx, TFGraph gr
/// </summary>
/// <param name="env">The environment to use.</param>
/// <param name="modelPath">Model to load.</param>
public static DataViewSchema GetModelSchema(IHostEnvironment env, string modelPath)
internal static DataViewSchema GetModelSchema(IHostEnvironment env, string modelPath)
{
var model = LoadTensorFlowModel(env, modelPath);
return GetModelSchema(env, model.Session.Graph);
}

/// <summary>
/// This is a convenience method for iterating over the nodes of a TensorFlow model graph. It
/// iterates over the columns of the <see cref="DataViewSchema"/> returned by <see cref="GetModelSchema(IHostEnvironment, string)"/>,
/// and for each one it returns a tuple containing the name, operation type, column type and an array of input node names.
/// This method is convenient for filtering nodes based on certain criteria, for example, by the operation type.
/// </summary>
/// <param name="env">The environment to use.</param>
/// <param name="modelPath">Model to load.</param>
/// <returns></returns>
public static IEnumerable<(string, string, DataViewType, string[])> GetModelNodes(IHostEnvironment env, string modelPath)
{
var schema = GetModelSchema(env, modelPath);

for (int i = 0; i < schema.Count; i++)
{
var name = schema[i].Name;
var type = schema[i].Type;

var metadataType = schema[i].Metadata.Schema.GetColumnOrNull(TensorflowOperatorTypeKind)?.Type;
Contracts.Assert(metadataType != null && metadataType is TextDataViewType);
ReadOnlyMemory<char> opType = default;
schema[i].Metadata.GetValue(TensorflowOperatorTypeKind, ref opType);
metadataType = schema[i].Metadata.Schema.GetColumnOrNull(TensorflowUpstreamOperatorsKind)?.Type;
VBuffer<ReadOnlyMemory<char>> inputOps = default;
if (metadataType != null)
{
Contracts.Assert(metadataType.IsKnownSizeVector() && metadataType.GetItemType() is TextDataViewType);
schema[i].Metadata.GetValue(TensorflowUpstreamOperatorsKind, ref inputOps);
}

string[] inputOpsResult = inputOps.DenseValues()
.Select(input => input.ToString())
.ToArray();

yield return (name, opType.ToString(), type, inputOpsResult);
}
}

internal static PrimitiveDataViewType Tf2MlNetType(TFDataType type)
{
var mlNetType = Tf2MlNetTypeOrNull(type);
@@ -338,10 +300,10 @@ private static void CreateTempDirectoryWithAcl(string folder, string identity)
/// <param name="env">The environment to use.</param>
/// <param name="modelPath">The model to load.</param>
/// <returns></returns>
public static TensorFlowModelInfo LoadTensorFlowModel(IHostEnvironment env, string modelPath)
internal static TensorFlowModelInfo LoadTensorFlowModel(IHostEnvironment env, string modelPath)
{
var session = GetSession(env, modelPath);
return new TensorFlowModelInfo(env, session, modelPath);
return new TensorFlowModelInfo(session, modelPath);
}

internal static TFSession GetSession(IHostEnvironment env, string modelPath)
33 changes: 2 additions & 31 deletions src/Microsoft.ML.TensorFlow/TensorFlowModelInfo.cs
@@ -11,50 +11,21 @@ namespace Microsoft.ML.Transforms
/// <summary>
/// This class holds the information related to TensorFlow model and session.
/// It provides a convenient way to query model schema as follows.
/// <list type="bullet">
/// <item>
/// <description>Get complete schema by calling <see cref="GetModelSchema()"/>.</description>
/// </item>
/// <item>
/// <description>Get schema related to model input(s) by calling <see cref="GetInputSchema()"/>.</description>
/// </item>
/// </list>
/// </summary>
public class TensorFlowModelInfo
public sealed class TensorFlowModelInfo
@yaeldekel, Feb 21, 2019

TensorFlowModelInfo (start = 24, length = 19)

Would it be possible to rename this to TensorFlowModel? #Resolved

{
internal TFSession Session { get; }
public string ModelPath { get; }
@yaeldekel, Feb 21, 2019

public (start = 8, length = 6)

I think this can also be internal. #Resolved


private readonly IHostEnvironment _env;

/// <summary>
/// Instantiates <see cref="TensorFlowModelInfo"/>.
/// </summary>
/// <param name="env">An <see cref="IHostEnvironment"/> object.</param>
/// <param name="session">TensorFlow session object.</param>
/// <param name="modelLocation">Location of the model from where <paramref name="session"/> was loaded.</param>
internal TensorFlowModelInfo(IHostEnvironment env, TFSession session, string modelLocation)
internal TensorFlowModelInfo(TFSession session, string modelLocation)
{
Session = session;
ModelPath = modelLocation;
_env = env;
}

/// <summary>
/// Get <see cref="DataViewSchema"/> for complete model. Every node in the TensorFlow model will be included in the <see cref="DataViewSchema"/> object.
/// </summary>
public DataViewSchema GetModelSchema()
{
return TensorFlowUtils.GetModelSchema(_env, Session.Graph);
}

/// <summary>
/// Get <see cref="DataViewSchema"/> for only those nodes which are marked "Placeholder" in the TensorFlow model.
/// This method is convenient for exploring the model input(s) in case TensorFlow graph is very large.
/// </summary>
public DataViewSchema GetInputSchema()
{
return TensorFlowUtils.GetModelSchema(_env, Session.Graph, "Placeholder");
}
}
}
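
For comparison, a minimal sketch of the instance-method surface removed above, written under the TensorFlowModel rename suggested in the review; it mirrors the deleted GetModelSchema/GetInputSchema members and is not code from this PR:

// Library-side sketch: the schema queries as instance members of the (renamed) model
// class, mirroring the members deleted in this diff.
// (Same using directives as TensorFlowModelInfo.cs.)
namespace Microsoft.ML.Transforms
{
    public sealed class TensorFlowModel
    {
        internal TFSession Session { get; }
        internal string ModelPath { get; }
        private readonly IHostEnvironment _env;

        internal TensorFlowModel(IHostEnvironment env, TFSession session, string modelLocation)
        {
            _env = env;
            Session = session;
            ModelPath = modelLocation;
        }

        // Every graph node appears as a column of the returned schema.
        public DataViewSchema GetModelSchema()
            => TensorFlowUtils.GetModelSchema(_env, Session.Graph);

        // Only nodes with OpType "Placeholder", i.e. the model inputs.
        public DataViewSchema GetInputSchema()
            => TensorFlowUtils.GetModelSchema(_env, Session.Graph, "Placeholder");
    }
}
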
46 changes: 40 additions & 6 deletions src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs
@@ -2,9 +2,11 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.Generic;
using Microsoft.Data.DataView;
using Microsoft.ML.Data;
using Microsoft.ML.Transforms;
using Microsoft.ML.Transforms.TensorFlow;

namespace Microsoft.ML
{
@@ -25,7 +27,7 @@ public static class TensorflowCatalog
/// ]]>
/// </format>
/// </example>
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog catalog,
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog.TensorFlowTransforms catalog,
string modelLocation,
string outputColumnName,
string inputColumnName)
@@ -45,7 +47,7 @@ public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog ca
/// ]]>
/// </format>
/// </example>
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog catalog,
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog.TensorFlowTransforms catalog,
string modelLocation,
string[] outputColumnNames,
string[] inputColumnNames)
@@ -58,7 +60,7 @@ public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog ca
/// <param name="tensorFlowModel">The pre-loaded TensorFlow model.</param>
/// <param name="inputColumnName"> The name of the model input.</param>
/// <param name="outputColumnName">The name of the requested model output.</param>
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog catalog,
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog.TensorFlowTransforms catalog,
TensorFlowModelInfo tensorFlowModel,
string outputColumnName,
string inputColumnName)
@@ -78,7 +80,7 @@ public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog ca
/// ]]>
/// </format>
/// </example>
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog catalog,
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog.TensorFlowTransforms catalog,
TensorFlowModelInfo tensorFlowModel,
string[] outputColumnNames,
string[] inputColumnNames)
@@ -90,7 +92,7 @@ public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog ca
/// </summary>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="options">The <see cref="TensorFlowEstimator.Options"/> specifying the inputs and the settings of the <see cref="TensorFlowEstimator"/>.</param>
public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog,
public static TensorFlowEstimator TensorFlow(this TransformsCatalog.TensorFlowTransforms catalog,
TensorFlowEstimator.Options options)
=> new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), options);

@@ -100,9 +102,41 @@ public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog,
/// <param name="catalog">The transform's catalog.</param>
/// <param name="options">The <see cref="TensorFlowEstimator.Options"/> specifying the inputs and the settings of the <see cref="TensorFlowEstimator"/>.</param>
/// <param name="tensorFlowModel">The pre-loaded TensorFlow model.</param>
public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog,
public static TensorFlowEstimator TensorFlow(this TransformsCatalog.TensorFlowTransforms catalog,
TensorFlowEstimator.Options options,
TensorFlowModelInfo tensorFlowModel)
=> new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), options, tensorFlowModel);

/// <summary>
/// This method retrieves the information about the graph nodes of a TensorFlow model as a <see cref="DataViewSchema"/>.
/// For every node in the graph that has an output type that is compatible with the types supported by
/// <see cref="TensorFlowEstimator"/>, the output schema contains a column with the name of that node, and the
/// type of its output (including the item type and the shape, if it is known). Every column also contains metadata
/// of kind <see cref="TensorFlowUtils.TensorflowOperatorTypeKind"/>, indicating the operation type of the node, and if that node has inputs in the graph,
/// it contains metadata of kind <see cref="TensorFlowUtils.TensorflowUpstreamOperatorsKind"/>, indicating the names of the input nodes.
/// </summary>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="tensorFlowModel">The pre-loaded TensorFlow model. Please see <see cref="LoadTensorFlowModel(TransformsCatalog.TensorFlowTransforms, string)"/> to know more about loading model into memory.</param>
public static DataViewSchema GetModelSchema(this TransformsCatalog.TensorFlowTransforms catalog, TensorFlowModelInfo tensorFlowModel)
=> TensorFlowUtils.GetModelSchema(CatalogUtils.GetEnvironment(catalog), tensorFlowModel.Session.Graph);

/// <summary>
/// This method retrieves the information about the input graph nodes of a TensorFlow model as a <see cref="DataViewSchema"/>.
/// The nodes with OpType "Placeholder" are classified as input nodes. This is a convenience method for inspecting only the input nodes.
/// For retrieving complete information please see <see cref="GetModelSchema(TransformsCatalog.TensorFlowTransforms, TensorFlowModelInfo)"/>
/// </summary>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="tensorFlowModel">The pre-loaded TensorFlow model. Please see <see cref="LoadTensorFlowModel(TransformsCatalog.TensorFlowTransforms, string)"/> to know more about loading model into memory.</param>
public static DataViewSchema GetInputSchema(this TransformsCatalog.TensorFlowTransforms catalog, TensorFlowModelInfo tensorFlowModel)
@TomFinley (Contributor), Feb 21, 2019

GetInputSchema (start = 37, length = 14)

This API does not seem like any C# API I've seen. It seems more like a C API with a bunch of static functions that you call with a pointer to a structure, rather than anything like what I'd expect in an actual object oriented property. Why is not everything you've added as an extension method, except possibly the creation of what you call now this TensorFlowModelInfo, just a property or a method of that TensorFlowModelInfo? #Resolved

Tom, I was under the impression that we wanted to have as much of the functionality as possible accessible from MLContext, and expose as few classes as possible, that's why I suggested exposing these methods here. Do you think the LoadTensorFlowModel API should still be an extension method on TransformsCatalog, or should we expose the TensorFlowUtils class for that?


In reply to: 258769868

@TomFinley (Contributor), Feb 22, 2019

I don't think that's how we've been doing things at all. We use MLContext for things like component creation, but once you, to give the most conspicuous example, create an estimator, the creation of a transformer, the getting of information out of the transformer, the usage of the transformer to transform a data view, that does not involve the MLContext.

You create a model. The model is the object, you use objects by calling methods and properties on those objects, just like everything else. Recall the title of #1098 that is I'd say the central MLContext... one MLContext to create them all, not to use them all.

I gave examples above about how we don't structure our APIs as they are structured here, and I think my reasoning is pretty solid. #Resolved

@zeahmed (Contributor, Author)

Thanks @TomFinley, I hope I addressed your concern here and in other places. I am marking it resolved for now.


In reply to: 259531678

Contributor

Marking unresolved... I do not see that anything has changed at all. TensorFlowModel has no meaningful methods on it, everything is still being exposed via these extension methods on top of properties.


In reply to: 259545815 [](ancestors = 259545815,259531678)

@zeahmed (Contributor, Author), Feb 27, 2019

Thanks @TomFinley for the feedback. I have refactored the code according to your comments.


In reply to: 260027159

=> TensorFlowUtils.GetModelSchema(CatalogUtils.GetEnvironment(catalog), tensorFlowModel.Session.Graph, "Placeholder");

/// <summary>
/// Load TensorFlow model into memory. This is a convenience method that allows the model to be loaded once and subsequently used for querying the schema and for creating a
/// <see cref="TensorFlowEstimator"/> using <see cref="TensorFlow(TransformsCatalog.TensorFlowTransforms, TensorFlowEstimator.Options, TensorFlowModelInfo)"/>.
/// </summary>
/// <param name="catalog">The transform's catalog.</param>
/// <param name="modelLocation">Location of the TensorFlow model.</param>
public static TensorFlowModelInfo LoadTensorFlowModel(this TransformsCatalog.TensorFlowTransforms catalog, string modelLocation)
=> TensorFlowUtils.LoadTensorFlowModel(CatalogUtils.GetEnvironment(catalog), modelLocation);
}
}
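
Taken together, a minimal sketch of how the extensions in this file compose in this revision of the PR (the model path is a placeholder; the node names follow the sentiment sample earlier in the PR):

// Sketch only: composes the catalog extensions added above. modelLocation is a
// placeholder path, and the node names ("Prediction/Softmax", "Features") follow the
// sentiment sample earlier in this PR.
using Microsoft.ML;

public static class TensorFlowCatalogUsage
{
    public static void Example(string modelLocation)
    {
        var mlContext = new MLContext();

        // Load the model once, then reuse it for schema inspection and scoring.
        var modelInfo = mlContext.Transforms.TensorFlow.LoadTensorFlowModel(modelLocation);

        // Full graph schema, and the "Placeholder"-only (input) view of it.
        var schema = mlContext.Transforms.TensorFlow.GetModelSchema(modelInfo);
        var inputSchema = mlContext.Transforms.TensorFlow.GetInputSchema(modelInfo);

        // Score with the pre-loaded model inside a pipeline.
        var estimator = mlContext.Transforms.TensorFlow.ScoreTensorFlowModel(
            modelInfo, new[] { "Prediction/Softmax" }, new[] { "Features" });
    }
}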