Create API for extracting information about the nodes in a TensorFlow model #862
using System.Collections.Generic;
using System.Linq.Expressions;
using Microsoft.ML.Runtime.Internal.Utilities;
using Microsoft.ML.Runtime.Data;
Are these using statements actually necessary? I don't see the additions that use them. #Resolved
Thanks for pointing it out. I added the new API in this file, but ended up moving it to TensorflowUtils.
In reply to: 216493797 [](ancestors = 216493797)
@@ -700,6 +698,24 @@ public override string ToString()
IntPtr len;
return TF_GraphDebugString(Handle, out len);
}

[DllImport(NativeBinding.TensorFlowLibrary)]
internal static extern string TF_OperationName(TF_Operation oper);
Rather than directly exposing the TF C-API we should consider bringing back the TF# defined API (TFOperation class) on top of it. See zeahmed@b2a8016#diff-ec7ea5716f3c05f773d3e1507b4f486aL729, zeahmed@b2a8016#diff-ec7ea5716f3c05f773d3e1507b4f486aL748. #Resolved
continue;

var numInputs = TFGraph.TF_OperationNumInputs(oper);
if (numInputs == 0)
numInputs [](start = 20, length = 9)
What does numInputs == 0 indicate? I believe the input node does not have any inputs. #Closed
The input has OpType "Placeholder". There are other nodes with numInputs==0, which have OpType "Const". I am not sure what they do but I think we don't want them in our output schema. What do you think?
In reply to: 216508337 [](ancestors = 216508337)
What do you think about not filtering such nodes? Do you foresee any adverse effect of filtering them, or of leaving them in?
In reply to: 216758114 [](ancestors = 216758114,216508337)
I added the nodes with 0 inputs to the schema as well.
In reply to: 216817673 [](ancestors = 216817673,216758114,216508337)
while ((oper = TFGraph.TF_GraphNextOperation(graph.handle, &pos)) != IntPtr.Zero)
{
var name = TFGraph.TF_OperationName(oper);
var type = TFGraph.TF_OperationOpType(oper);
type [](start = 20, length = 4)
It's not being used anywhere. #Closed
var model_location = "mnist_model/frozen_saved_model.pb";
using (var env = new TlcEnvironment(seed: 1, conc: 1))
{
var schema = TensorFlowUtils.GetModelSchema(env, model_location);
schema [](start = 20, length = 6)
It would be nice to have the schema checked against the actual model information. mnist_model/frozen_saved_model.pb seems to be a big model; the matrix multiplication model used in the test above may be a good fit for this check.
It can be implemented as a new test in addition to this one. #Closed
public static ISchema GetModelSchema(IExceptionContext ectx, string modelFile)
{
var bytes = File.ReadAllBytes(modelFile);
var session = LoadTFSession(ectx, bytes, modelFile);
LoadTFSession [](start = 26, length = 13)
What about models that are not frozen? I assume support for them will be there once @abgoswam's changes are in, right? #Closed
Yes. I assume that LoadTFSession will take care of the logic to decide which kind of model to load.
In reply to: 216509938 [](ancestors = 216509938)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(RegistrationName));
_host.CheckValue(modelBytes, nameof(modelBytes));
Session = LoadTFSession(modelBytes);
Session = TensorFlowUtils.LoadTFSession(_host, modelBytes, modelFile);
Abhishek is working on bringing in another way to load a session, so I think it's better to extend the CheckFileAndRead function and have it return a Session instead of a byte array. The private constructor would then just accept the session. #Closed
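The shape of this suggestion can be sketched roughly as follows. This is only an illustration under stated assumptions, not code from this PR: ModelSession, ModelTransform and LoadSession are hypothetical stand-ins for TFSession, TensorFlowTransform and an extended CheckFileAndRead, and the stand-in session merely wraps the file bytes.

```csharp
using System;
using System.IO;

// Hypothetical stand-in for TFSession; it only remembers how many bytes were loaded.
public sealed class ModelSession
{
    public int ByteCount { get; }

    public ModelSession(byte[] modelBytes)
    {
        ByteCount = modelBytes.Length;
    }
}

// Hypothetical stand-in for TensorFlowTransform.
public sealed class ModelTransform
{
    public ModelSession Session { get; }
    public string[] Inputs { get; }
    public string[] Outputs { get; }

    // Public entry point: validates the path and builds the session up front.
    public ModelTransform(string modelFile, string[] inputs, string[] outputs)
        : this(LoadSession(modelFile), inputs, outputs)
    {
    }

    // The private constructor only sees a ready-made session, so adding another
    // way of loading a model later does not require changing it.
    private ModelTransform(ModelSession session, string[] inputs, string[] outputs)
    {
        Session = session ?? throw new ArgumentNullException(nameof(session));
        Inputs = inputs ?? throw new ArgumentNullException(nameof(inputs));
        Outputs = outputs ?? throw new ArgumentNullException(nameof(outputs));
    }

    // Plays the role of an extended CheckFileAndRead that returns a session
    // instead of a byte array (assumption: the real loader would differ).
    private static ModelSession LoadSession(string modelFile)
    {
        if (string.IsNullOrWhiteSpace(modelFile))
            throw new ArgumentException("Model path must be non-empty.", nameof(modelFile));
        if (!File.Exists(modelFile))
            throw new FileNotFoundException("Model file not found.", modelFile);
        return new ModelSession(File.ReadAllBytes(modelFile));
    }
}
```

With this layout the private constructor no longer needs the model path or the raw bytes; whichever loading strategy is used, it receives the same kind of object.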
@@ -182,16 +166,16 @@ private static byte[] CheckFileAndRead(IHostEnvironment env, string modelFile)
}

public TensorFlowTransform(IHostEnvironment env, string modelFile, string[] inputs, string[] outputs) :
this(env, CheckFileAndRead(env, modelFile), inputs, outputs)
this(env, CheckFileAndRead(env, modelFile), inputs, outputs, modelFile)
modelFile [](start = 73, length = 9)
is it model file or model args? #Closed
Is this still WIP, or is it OK to review it properly? I see people signing off while the title still says WIP, and I find this a bit confusing. #Resolved
…nodes, and a console app that displays it
return;
}

foreach (var (name, opType, type, inputs) in TensorFlowUtils.GetModelNodes(args[0]))
foreach [](start = 16, length = 7)
I can add more arguments to this app to let the user filter nodes by operation type (for example, there are sometimes lots of "Const" nodes that users might not be interested in if they are just trying to find the name of a certain layer).
Is this valuable? Is having a method that returns this information enough, so that users can filter programmatically, or would users want this in the app as well?
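For reference, filtering programmatically on the caller's side could look roughly like the sketch below. It assumes the node information arrives as tuples shaped like the ones in the foreach above; the NodeFilter class, the ExcludeOpTypes helper and the sample nodes are made up for illustration, and the column type is represented as a plain string because its real type is not shown in this thread.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class NodeFilter
{
    // Hypothetical helper: drops every node whose operation type is in excludedOpTypes.
    // The tuple shape mirrors (name, opType, type, inputs) from the loop above;
    // "type" is a string here only as a stand-in.
    public static IEnumerable<(string name, string opType, string type, string[] inputs)> ExcludeOpTypes(
        IEnumerable<(string name, string opType, string type, string[] inputs)> nodes,
        params string[] excludedOpTypes)
    {
        var excluded = new HashSet<string>(excludedOpTypes, StringComparer.Ordinal);
        return nodes.Where(n => !excluded.Contains(n.opType));
    }

    public static void Main()
    {
        // Made-up nodes standing in for what a real model would report.
        var nodes = new List<(string name, string opType, string type, string[] inputs)>
        {
            ("Input", "Placeholder", "float[784]", new string[0]),
            ("Weights", "Const", "float[784,10]", new string[0]),
            ("MatMul", "MatMul", "float[10]", new[] { "Input", "Weights" }),
        };

        // Print every node except the "Const" ones.
        foreach (var (name, opType, type, inputs) in ExcludeOpTypes(nodes, "Const"))
            Console.WriteLine($"{name}\t{opType}\t{type}\t{string.Join(",", inputs)}");
    }
}
```

If a one-line LINQ filter like this is enough for most users, the console app could stay argument-free and simply print everything.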
Please review properly, I removed the WIP from the title.
In reply to: 421182387 [](ancestors = 421182387)
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>netcoreapp2.1</TargetFramework>
<AssemblyName>DnnAnalyzer</AssemblyName>
should it be part of TensorFlow package? #Closed
{
if (Utils.Size(args) != 1)
{
ch.Error("Usage: dotnet DnnAnalyzer.dll <model_location>");
.dll [](start = 55, length = 4)
Is the .dll necessary? And should it be dotnet run, or does plain dotnet work? #Closed
"dotnet DnnAnalyzer " didn't work.
"dotnet run DnnAnalyzer.dll " didn't work.
Is there a different syntax I should try?
In reply to: 218189666 [](ancestors = 218189666)
No, apparently I need to study dotnet better. dotnet run is for projects only.
In reply to: 218535326 [](ancestors = 218535326,218189666)
{
public static void Main(string[] args)
{
using (var env = new TlcEnvironment())
TlcEnvironment [](start = 33, length = 14)
Doesn't it feel weird to create an environment only to write something to the console?
Why can't you just use Console.WriteLine? #Resolved
{
}

private TensorFlowTransform(IHostEnvironment env, byte[] modelBytes, string[] inputs, string[] outputs)
private TensorFlowTransform(IHostEnvironment env, TFSession session, string[] inputs, string[] outputs, string modelFile = null)
string modelFile = null [](start = 112, length = 23)
not necessary anymore #Closed
var opTypeGetters = new List<MetadataUtils.MetadataGetter<DvText>>();
var inputOpsGetters = new List<MetadataUtils.MetadataGetter<VBuffer<DvText>>>();
var inputOpsLengths = new List<int>();
foreach (var oper in graph)
oper [](start = 25, length = 4)
nit: If you shorten operation to oper, you might as well go further and use "op". #Closed
var inputOpsLengths = new List<int>();
foreach (var oper in graph)
{
if (oper.NumOutputs != 1)
if (oper.NumOutputs != 1) [](start = 16, length = 25)
I think this deserves a comment.
I removed this code. I think that if a node has multiple outputs it means that the output it produces is used as input to multiple nodes, but the shape and type of this output will be the same for every node that uses it as input. In this case there is no need to skip it. @zeahmed, is my assumption correct?
In reply to: 218204247 [](ancestors = 218204247)
Actually, a tf.Operation can produce multiple outputs with different types and shapes. When a single output is used by multiple nodes, I am not sure whether oper.NumOutputs > 1 holds. Technically, recurrent layers can produce two outputs (hidden state, output), but I would need to see how the recurrent ops are implemented in graphs.
In reply to: 218551502 [](ancestors = 218551502,218204247)
@@ -25,7 +30,92 @@ public static void Initialize()
ImageAnalytics.Initialize();
}

private static unsafe ISchema GetModelSchema(IExceptionContext ectx, TFGraph graph)
unsafe [](start = 23, length = 6)
Just curious, why is it unsafe? #Closed
Thanks for pointing it out, it doesn't need to be. In a previous iteration I was directly running the while loop over the graph that is now in the Graph class, so it was unsafe.
In reply to: 218204434 [](ancestors = 218204434)
🕐
This PR addresses issue #791 .
Please feel free to add feedback or suggestions.