Skip to content

Commit 72a4eb6

Browse files
authored
Tensorflow GetModelSchema should support unfrozen models (dotnet#2112)
1 parent a927479 commit 72a4eb6

File tree

2 files changed

+17
-12
lines changed

2 files changed

+17
-12
lines changed

src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -75,27 +75,25 @@ internal static Schema GetModelSchema(IExceptionContext ectx, TFGraph graph, str
7575
/// of kind <see cref="OpType"/>, indicating the operation type of the node, and if that node has inputs in the graph,
7676
/// it contains metadata of kind <see cref="InputOps"/>, indicating the names of the input nodes.
7777
/// </summary>
78-
/// <param name="ectx">An <see cref="IExceptionContext"/>.</param>
79-
/// <param name="modelFile">The name of the file containing the TensorFlow model. Currently only frozen model
80-
/// format is supported.</param>
81-
public static Schema GetModelSchema(IExceptionContext ectx, string modelFile)
78+
/// <param name="env">The environment to use.</param>
79+
/// <param name="modelPath">Model to load.</param>
80+
public static Schema GetModelSchema(IHostEnvironment env, string modelPath)
8281
{
83-
var bytes = File.ReadAllBytes(modelFile);
84-
var session = LoadTFSession(ectx, bytes, modelFile);
85-
return GetModelSchema(ectx, session.Graph);
82+
var model = LoadTensorFlowModel(env, modelPath);
83+
return GetModelSchema(env, model.Session.Graph);
8684
}
8785

8886
/// <summary>
8987
/// This is a convenience method for iterating over the nodes of a TensorFlow model graph. It
90-
/// iterates over the columns of the <see cref="ISchema"/> returned by <see cref="GetModelSchema(IExceptionContext, string)"/>,
88+
/// iterates over the columns of the <see cref="ISchema"/> returned by <see cref="GetModelSchema(IHostEnvironment, string)"/>,
9189
/// and for each one it returns a tuple containing the name, operation type, column type and an array of input node names.
9290
/// This method is convenient for filtering nodes based on certain criteria, for example, by the operation type.
9391
/// </summary>
94-
/// <param name="modelFile"></param>
92+
/// <param name="modelPath">Model to load.</param>
9593
/// <returns></returns>
96-
public static IEnumerable<(string, string, ColumnType, string[])> GetModelNodes(string modelFile)
94+
public static IEnumerable<(string, string, ColumnType, string[])> GetModelNodes(string modelPath)
9795
{
98-
var schema = GetModelSchema(null, modelFile);
96+
var schema = GetModelSchema(new MLContext(), modelPath);
9997

10098
for (int i = 0; i < schema.Count; i++)
10199
{
@@ -310,6 +308,12 @@ private static void CreateTempDirectoryWithAcl(string folder, string identity)
310308
}
311309
}
312310

311+
/// <summary>
312+
/// Load TensorFlow model into memory.
313+
/// </summary>
314+
/// <param name="env">The environment to use.</param>
315+
/// <param name="modelPath">The model to load.</param>
316+
/// <returns></returns>
313317
public static TensorFlowModelInfo LoadTensorFlowModel(IHostEnvironment env, string modelPath)
314318
{
315319
var session = GetSession(env, modelPath);

test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -606,8 +606,9 @@ public void TensorFlowTransformCifar()
606606
public void TensorFlowTransformCifarSavedModel()
607607
{
608608
var modelLocation = "cifar_saved_model";
609-
610609
var mlContext = new MLContext(seed: 1, conc: 1);
610+
var loadModelSchema = TensorFlowUtils.GetModelSchema(mlContext, modelLocation);
611+
Assert.Equal(335, loadModelSchema.Count);
611612
var tensorFlowModel = TensorFlowUtils.LoadTensorFlowModel(mlContext, modelLocation);
612613
var schema = tensorFlowModel.GetInputSchema();
613614
Assert.True(schema.TryGetColumnIndex("Input", out int column));

0 commit comments

Comments
 (0)