@@ -75,27 +75,25 @@ internal static Schema GetModelSchema(IExceptionContext ectx, TFGraph graph, str
75
75
/// of kind <see cref="OpType"/>, indicating the operation type of the node, and if that node has inputs in the graph,
76
76
/// it contains metadata of kind <see cref="InputOps"/>, indicating the names of the input nodes.
77
77
/// </summary>
78
- /// <param name="ectx">An <see cref="IExceptionContext"/>.</param>
79
- /// <param name="modelFile">The name of the file containing the TensorFlow model. Currently only frozen model
80
- /// format is supported.</param>
81
- public static Schema GetModelSchema ( IExceptionContext ectx , string modelFile )
78
+ /// <param name="env">The environment to use.</param>
79
+ /// <param name="modelPath">Model to load.</param>
80
+ public static Schema GetModelSchema ( IHostEnvironment env , string modelPath )
82
81
{
83
- var bytes = File . ReadAllBytes ( modelFile ) ;
84
- var session = LoadTFSession ( ectx , bytes , modelFile ) ;
85
- return GetModelSchema ( ectx , session . Graph ) ;
82
+ var model = LoadTensorFlowModel ( env , modelPath ) ;
83
+ return GetModelSchema ( env , model . Session . Graph ) ;
86
84
}
87
85
88
86
/// <summary>
89
87
/// This is a convenience method for iterating over the nodes of a TensorFlow model graph. It
90
- /// iterates over the columns of the <see cref="ISchema"/> returned by <see cref="GetModelSchema(IExceptionContext , string)"/>,
88
+ /// iterates over the columns of the <see cref="ISchema"/> returned by <see cref="GetModelSchema(IHostEnvironment , string)"/>,
91
89
/// and for each one it returns a tuple containing the name, operation type, column type and an array of input node names.
92
90
/// This method is convenient for filtering nodes based on certain criteria, for example, by the operation type.
93
91
/// </summary>
94
- /// <param name="modelFile"> </param>
92
+ /// <param name="modelPath">Model to load. </param>
95
93
/// <returns></returns>
96
- public static IEnumerable < ( string , string , ColumnType , string [ ] ) > GetModelNodes ( string modelFile )
94
+ public static IEnumerable < ( string , string , ColumnType , string [ ] ) > GetModelNodes ( string modelPath )
97
95
{
98
- var schema = GetModelSchema ( null , modelFile ) ;
96
+ var schema = GetModelSchema ( new MLContext ( ) , modelPath ) ;
99
97
100
98
for ( int i = 0 ; i < schema . Count ; i ++ )
101
99
{
@@ -310,6 +308,12 @@ private static void CreateTempDirectoryWithAcl(string folder, string identity)
310
308
}
311
309
}
312
310
311
+ /// <summary>
312
+ /// Load TensorFlow model into memory.
313
+ /// </summary>
314
+ /// <param name="env">The environment to use.</param>
315
+ /// <param name="modelPath">The model to load.</param>
316
+ /// <returns></returns>
313
317
public static TensorFlowModelInfo LoadTensorFlowModel ( IHostEnvironment env , string modelPath )
314
318
{
315
319
var session = GetSession ( env , modelPath ) ;
0 commit comments