Skip to content

Commit 2d0b37b

Browse files
Ivanidzo4kawschin
authored andcommitted
Tensorflow GetModelSchema should support unfrozen models (dotnet#2112)
1 parent 96b2995 commit 2d0b37b

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs

+15-11
Original file line numberDiff line numberDiff line change
@@ -73,27 +73,25 @@ internal static Schema GetModelSchema(IExceptionContext ectx, TFGraph graph, str
7373
/// of kind <see cref="OpType"/>, indicating the operation type of the node, and if that node has inputs in the graph,
7474
/// it contains metadata of kind <see cref="InputOps"/>, indicating the names of the input nodes.
7575
/// </summary>
76-
/// <param name="ectx">An <see cref="IExceptionContext"/>.</param>
77-
/// <param name="modelFile">The name of the file containing the TensorFlow model. Currently only frozen model
78-
/// format is supported.</param>
79-
public static Schema GetModelSchema(IExceptionContext ectx, string modelFile)
76+
/// <param name="env">The environment to use.</param>
77+
/// <param name="modelPath">Model to load.</param>
78+
public static Schema GetModelSchema(IHostEnvironment env, string modelPath)
8079
{
81-
var bytes = File.ReadAllBytes(modelFile);
82-
var session = LoadTFSession(ectx, bytes, modelFile);
83-
return GetModelSchema(ectx, session.Graph);
80+
var model = LoadTensorFlowModel(env, modelPath);
81+
return GetModelSchema(env, model.Session.Graph);
8482
}
8583

8684
/// <summary>
8785
/// This is a convenience method for iterating over the nodes of a TensorFlow model graph. It
88-
/// iterates over the columns of the <see cref="ISchema"/> returned by <see cref="GetModelSchema(IExceptionContext, string)"/>,
86+
/// iterates over the columns of the <see cref="ISchema"/> returned by <see cref="GetModelSchema(IHostEnvironment, string)"/>,
8987
/// and for each one it returns a tuple containing the name, operation type, column type and an array of input node names.
9088
/// This method is convenient for filtering nodes based on certain criteria, for example, by the operation type.
9189
/// </summary>
92-
/// <param name="modelFile"></param>
90+
/// <param name="modelPath">Model to load.</param>
9391
/// <returns></returns>
94-
public static IEnumerable<(string, string, ColumnType, string[])> GetModelNodes(string modelFile)
92+
public static IEnumerable<(string, string, ColumnType, string[])> GetModelNodes(string modelPath)
9593
{
96-
var schema = GetModelSchema(null, modelFile);
94+
var schema = GetModelSchema(new MLContext(), modelPath);
9795

9896
for (int i = 0; i < schema.Count; i++)
9997
{
@@ -308,6 +306,12 @@ private static void CreateTempDirectoryWithAcl(string folder, string identity)
308306
}
309307
}
310308

309+
/// <summary>
310+
/// Load TensorFlow model into memory.
311+
/// </summary>
312+
/// <param name="env">The environment to use.</param>
313+
/// <param name="modelPath">The model to load.</param>
314+
/// <returns></returns>
311315
public static TensorFlowModelInfo LoadTensorFlowModel(IHostEnvironment env, string modelPath)
312316
{
313317
var session = GetSession(env, modelPath);

0 commit comments

Comments
 (0)