@@ -73,27 +73,25 @@ internal static Schema GetModelSchema(IExceptionContext ectx, TFGraph graph, str
73
73
/// of kind <see cref="OpType"/>, indicating the operation type of the node, and if that node has inputs in the graph,
74
74
/// it contains metadata of kind <see cref="InputOps"/>, indicating the names of the input nodes.
75
75
/// </summary>
76
- /// <param name="ectx">An <see cref="IExceptionContext"/>.</param>
77
- /// <param name="modelFile">The name of the file containing the TensorFlow model. Currently only frozen model
78
- /// format is supported.</param>
79
- public static Schema GetModelSchema ( IExceptionContext ectx , string modelFile )
76
+ /// <param name="env">The environment to use.</param>
77
+ /// <param name="modelPath">Model to load.</param>
78
+ public static Schema GetModelSchema ( IHostEnvironment env , string modelPath )
80
79
{
81
- var bytes = File . ReadAllBytes ( modelFile ) ;
82
- var session = LoadTFSession ( ectx , bytes , modelFile ) ;
83
- return GetModelSchema ( ectx , session . Graph ) ;
80
+ var model = LoadTensorFlowModel ( env , modelPath ) ;
81
+ return GetModelSchema ( env , model . Session . Graph ) ;
84
82
}
85
83
86
84
/// <summary>
87
85
/// This is a convenience method for iterating over the nodes of a TensorFlow model graph. It
88
- /// iterates over the columns of the <see cref="ISchema"/> returned by <see cref="GetModelSchema(IExceptionContext , string)"/>,
86
+ /// iterates over the columns of the <see cref="ISchema"/> returned by <see cref="GetModelSchema(IHostEnvironment , string)"/>,
89
87
/// and for each one it returns a tuple containing the name, operation type, column type and an array of input node names.
90
88
/// This method is convenient for filtering nodes based on certain criteria, for example, by the operation type.
91
89
/// </summary>
92
- /// <param name="modelFile"> </param>
90
+ /// <param name="modelPath">Model to load. </param>
93
91
/// <returns></returns>
94
- public static IEnumerable < ( string , string , ColumnType , string [ ] ) > GetModelNodes ( string modelFile )
92
+ public static IEnumerable < ( string , string , ColumnType , string [ ] ) > GetModelNodes ( string modelPath )
95
93
{
96
- var schema = GetModelSchema ( null , modelFile ) ;
94
+ var schema = GetModelSchema ( new MLContext ( ) , modelPath ) ;
97
95
98
96
for ( int i = 0 ; i < schema . Count ; i ++ )
99
97
{
@@ -308,6 +306,12 @@ private static void CreateTempDirectoryWithAcl(string folder, string identity)
308
306
}
309
307
}
310
308
309
+ /// <summary>
310
+ /// Load TensorFlow model into memory.
311
+ /// </summary>
312
+ /// <param name="env">The environment to use.</param>
313
+ /// <param name="modelPath">The model to load.</param>
314
+ /// <returns></returns>
311
315
public static TensorFlowModelInfo LoadTensorFlowModel ( IHostEnvironment env , string modelPath )
312
316
{
313
317
var session = GetSession ( env , modelPath ) ;
0 commit comments