-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Internalization of TensorFlowUtils.cs and refactored TensorFlowCatalog. #2672
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
e5eef19
b170f52
ee9b7ae
6cc3f1c
1abb719
7b4c08c
fc188cd
963b9cd
a78ba89
ffd534e
d1c0dd8
e742885
7cd88ed
509bb12
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,6 @@ | |
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.Data.DataView; | ||
using Microsoft.ML.Data; | ||
using Microsoft.ML.Transforms.TensorFlow; | ||
|
||
namespace Microsoft.ML.Transforms | ||
|
@@ -20,20 +19,20 @@ namespace Microsoft.ML.Transforms | |
/// </item> | ||
/// </list> | ||
/// </summary> | ||
public class TensorFlowModelInfo | ||
public sealed class TensorFlowModel | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think possibly maybe there's some misunderstanding. Just to be explicit, I expect to see methods used to query the model to be on the model, that is, they will be just methods and properties of the model. Creatingtransformers, querying the schema, or whatever, will be here. The implication is that nearly everything in |
||
{ | ||
internal TFSession Session { get; } | ||
public string ModelPath { get; } | ||
internal string ModelPath { get; } | ||
|
||
private readonly IHostEnvironment _env; | ||
|
||
/// <summary> | ||
/// Instantiates <see cref="TensorFlowModelInfo"/>. | ||
/// Instantiates <see cref="TensorFlowModel"/>. | ||
/// </summary> | ||
/// <param name="env">An <see cref="IHostEnvironment"/> object.</param> | ||
/// <param name="session">TensorFlow session object.</param> | ||
/// <param name="modelLocation">Location of the model from where <paramref name="session"/> was loaded.</param> | ||
internal TensorFlowModelInfo(IHostEnvironment env, TFSession session, string modelLocation) | ||
internal TensorFlowModel(IHostEnvironment env, TFSession session, string modelLocation) | ||
{ | ||
Session = session; | ||
ModelPath = modelLocation; | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
using Microsoft.Data.DataView; | ||
using Microsoft.ML.Data; | ||
using Microsoft.ML.Transforms; | ||
using Microsoft.ML.Transforms.TensorFlow; | ||
|
||
namespace Microsoft.ML | ||
{ | ||
|
@@ -59,7 +60,7 @@ public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog ca | |
/// <param name="inputColumnName"> The name of the model input.</param> | ||
/// <param name="outputColumnName">The name of the requested model output.</param> | ||
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog catalog, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This should be an operation on the model. #Resolved |
||
TensorFlowModelInfo tensorFlowModel, | ||
TensorFlowModel tensorFlowModel, | ||
string outputColumnName, | ||
string inputColumnName) | ||
=> new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), new[] { outputColumnName }, new[] { inputColumnName }, tensorFlowModel); | ||
|
@@ -79,7 +80,7 @@ public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog ca | |
/// </format> | ||
/// </example> | ||
public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog catalog, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Likewise on the model.. #Resolved |
||
TensorFlowModelInfo tensorFlowModel, | ||
TensorFlowModel tensorFlowModel, | ||
string[] outputColumnNames, | ||
string[] inputColumnNames) | ||
=> new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnNames, inputColumnNames, tensorFlowModel); | ||
|
@@ -102,7 +103,16 @@ public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog, | |
/// <param name="tensorFlowModel">The pre-loaded TensorFlow model.</param> | ||
public static TensorFlowEstimator TensorFlow(this TransformsCatalog catalog, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
So this I find pretty confusing. Do we create estimators via this method, or do we work through the model object? #Resolved |
||
TensorFlowEstimator.Options options, | ||
TensorFlowModelInfo tensorFlowModel) | ||
TensorFlowModel tensorFlowModel) | ||
=> new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), options, tensorFlowModel); | ||
|
||
/// <summary> | ||
/// Load TensorFlow model into memory. This is the convenience method that allows the model to be loaded once and subsequently use it for querying schema and creation of | ||
/// <see cref="TensorFlowEstimator"/> using <see cref="TensorFlow(TransformsCatalog, TensorFlowEstimator.Options, TensorFlowModel)"/>. | ||
/// </summary> | ||
/// <param name="catalog">The transform's catalog.</param> | ||
/// <param name="modelLocation">Location of the TensorFlow model.</param> | ||
public static TensorFlowModel LoadTensorFlowModel(this ModelOperationsCatalog catalog, string modelLocation) | ||
=> TensorFlowUtils.LoadTensorFlowModel(CatalogUtils.GetEnvironment(catalog), modelLocation); | ||
} | ||
} |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add a sample that uses
modelInfo.GetInputSchema()
to find out what the name of the input node is?#Resolved
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see its being used at a couple of places in the tests e.g.
machinelearning/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
Line 894 in eb959c3
machinelearning/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
Line 849 in eb959c3