-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Internalization of TensorFlowUtils.cs and refactored TensorFlowCatalog. #2672
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e5eef19
b170f52
ee9b7ae
6cc3f1c
1abb719
7b4c08c
fc188cd
963b9cd
a78ba89
ffd534e
d1c0dd8
e742885
7cd88ed
509bb12
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.Data.DataView; | ||
using Microsoft.ML.Transforms.TensorFlow; | ||
|
||
namespace Microsoft.ML.Transforms | ||
{ | ||
/// <summary> | ||
/// This class holds the information related to TensorFlow model and session. | ||
/// It provides some convenient methods to query model schema as well as | ||
/// creation of <see cref="TensorFlowEstimator"/> object. | ||
/// </summary> | ||
public sealed class TensorFlowModel | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think possibly maybe there's some misunderstanding. Just to be explicit, I expect to see methods used to query the model to be on the model, that is, they will be just methods and properties of the model. Creatingtransformers, querying the schema, or whatever, will be here. The implication is that nearly everything in |
||
{ | ||
internal TFSession Session { get; } | ||
internal string ModelPath { get; } | ||
|
||
private readonly IHostEnvironment _env; | ||
|
||
/// <summary> | ||
/// Instantiates <see cref="TensorFlowModel"/>. | ||
/// </summary> | ||
/// <param name="env">An <see cref="IHostEnvironment"/> object.</param> | ||
/// <param name="session">TensorFlow session object.</param> | ||
/// <param name="modelLocation">Location of the model from where <paramref name="session"/> was loaded.</param> | ||
internal TensorFlowModel(IHostEnvironment env, TFSession session, string modelLocation) | ||
{ | ||
Session = session; | ||
ModelPath = modelLocation; | ||
_env = env; | ||
} | ||
|
||
/// <summary> | ||
/// Get <see cref="DataViewSchema"/> for complete model. Every node in the TensorFlow model will be included in the <see cref="DataViewSchema"/> object. | ||
/// </summary> | ||
public DataViewSchema GetModelSchema() | ||
{ | ||
return TensorFlowUtils.GetModelSchema(_env, Session.Graph); | ||
} | ||
|
||
/// <summary> | ||
/// Get <see cref="DataViewSchema"/> for only those nodes which are marked "Placeholder" in the TensorFlow model. | ||
/// This method is convenient for exploring the model input(s) in case TensorFlow graph is very large. | ||
/// </summary> | ||
public DataViewSchema GetInputSchema() | ||
{ | ||
return TensorFlowUtils.GetModelSchema(_env, Session.Graph, "Placeholder"); | ||
} | ||
|
||
/// <summary> | ||
/// Scores a dataset using a pre-traiend <a href="https://www.tensorflow.org/">TensorFlow</a> model. | ||
/// </summary> | ||
/// <param name="inputColumnName"> The name of the model input.</param> | ||
/// <param name="outputColumnName">The name of the requested model output.</param> | ||
/// <example> | ||
/// <format type="text/markdown"> | ||
/// <] | ||
/// ]]> | ||
/// </format> | ||
/// </example> | ||
public TensorFlowEstimator ScoreTensorFlowModel(string outputColumnName, string inputColumnName) | ||
=> new TensorFlowEstimator(_env, new[] { outputColumnName }, new[] { inputColumnName }, ModelPath); | ||
|
||
/// <summary> | ||
/// Scores a dataset using a pre-traiend TensorFlow model. | ||
/// </summary> | ||
/// <param name="inputColumnNames"> The names of the model inputs.</param> | ||
/// <param name="outputColumnNames">The names of the requested model outputs.</param> | ||
/// <example> | ||
/// <format type="text/markdown"> | ||
/// <] | ||
/// ]]> | ||
/// </format> | ||
/// </example> | ||
public TensorFlowEstimator ScoreTensorFlowModel(string[] outputColumnNames, string[] inputColumnNames) | ||
=> new TensorFlowEstimator(_env, outputColumnNames, inputColumnNames, ModelPath); | ||
|
||
/// <summary> | ||
/// Retrain the TensorFlow model on new data. | ||
/// The model is not loaded again instead the information contained in <see cref="TensorFlowModel"/> class is reused | ||
/// (c.f. <see cref="TensorFlowModel.ModelPath"/> and <see cref="TensorFlowModel.Session"/>). | ||
/// </summary> | ||
/// <param name="inputColumnNames"> The names of the model inputs.</param> | ||
/// <param name="outputColumnNames">The names of the requested model outputs.</param> | ||
/// <param name="labelColumnName">Name of the label column.</param> | ||
/// <param name="tensorFlowLabel">Name of the node in TensorFlow graph that is used as label during training in TensorFlow. | ||
/// The value of <paramref name="labelColumnName"/> from <see cref="IDataView"/> is fed to this node.</param> | ||
/// <param name="optimizationOperation">The name of the optimization operation in the TensorFlow graph.</param> | ||
/// <param name="epoch">Number of training iterations.</param> | ||
/// <param name="batchSize">Number of samples to use for mini-batch training.</param> | ||
/// <param name="lossOperation">The name of the operation in the TensorFlow graph to compute training loss (Optional).</param> | ||
/// <param name="metricOperation">The name of the operation in the TensorFlow graph to compute performance metric during training (Optional).</param> | ||
/// <param name="learningRateOperation">The name of the operation in the TensorFlow graph which sets optimizer learning rate (Optional).</param> | ||
/// <param name="learningRate">Learning rate to use during optimization (Optional).</param> | ||
/// <remarks> | ||
/// The support for retraining is experimental. | ||
/// </remarks> | ||
public TensorFlowEstimator RetrainTensorFlowModel( | ||
string[] outputColumnNames, | ||
string[] inputColumnNames, | ||
string labelColumnName, | ||
string tensorFlowLabel, | ||
string optimizationOperation, | ||
int epoch = 10, | ||
int batchSize = 20, | ||
string lossOperation= null, | ||
string metricOperation = null, | ||
string learningRateOperation = null, | ||
float learningRate = 0.01f) | ||
{ | ||
var options = new TensorFlowEstimator.Options() | ||
{ | ||
ModelLocation = ModelPath, | ||
InputColumns = inputColumnNames, | ||
OutputColumns = outputColumnNames, | ||
LabelColumn = labelColumnName, | ||
TensorFlowLabel = tensorFlowLabel, | ||
OptimizationOperation = optimizationOperation, | ||
LossOperation = lossOperation, | ||
MetricOperation = metricOperation, | ||
Epoch = epoch, | ||
LearningRateOperation = learningRateOperation, | ||
LearningRate = learningRate, | ||
BatchSize = batchSize, | ||
ReTrain = true | ||
}; | ||
return new TensorFlowEstimator(_env, options, this); | ||
} | ||
} | ||
} |
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add a sample that uses
modelInfo.GetInputSchema()
to find out what the name of the input node is?#Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see its being used at a couple of places in the tests e.g.
machinelearning/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
Line 894 in eb959c3
machinelearning/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
Line 849 in eb959c3