@@ -20,51 +20,49 @@ public static class TensorFlowUtils
20
20
21
21
/// <summary>
/// Builds a <see cref="Schema"/> describing the nodes of a TensorFlow graph. One column is added for every
/// operation whose first output has an element type that can be mapped to an ML.NET type; the column is
/// named after the operation and is always a vector type.
/// </summary>
/// <param name="ectx">Exception context. Kept for interface compatibility; it is not used while building the schema.</param>
/// <param name="graph">The TensorFlow graph whose operations are enumerated.</param>
/// <param name="opType">When not null, only operations whose <c>OpType</c> equals this value are included.</param>
/// <returns>A schema with one column per retained operation, annotated with <see cref="OpType"/> metadata and,
/// for operations that have inputs, <see cref="InputOps"/> metadata.</returns>
internal static Schema GetModelSchema(IExceptionContext ectx, TFGraph graph, string opType = null)
{
    var schemaBuilder = new SchemaBuilder();
    foreach (var op in graph)
    {
        if (opType != null && opType != op.OpType)
            continue;

        var tfType = op[0].OutputType;
        // Determine element type in Tensorflow tensor. For example, a vector of floats may get NumberType.R4 here.
        var mlType = Tf2MlNetTypeOrNull(tfType);

        // If the type is not supported in ML.NET then we cannot represent it as a column in an ISchema.
        // We also cannot output it with a TensorFlowTransform, so we skip it.
        if (mlType == null)
            continue;

        // Construct the final ML.NET type of a Tensorflow variable. The default is a vector of unknown size.
        var shapeArray = graph.GetTensorShape(op[0]).ToIntArray();
        var columnType = new VectorType(mlType);

        // A rank-1 tensor whose only dimension is unknown (<= 0) must keep the unknown-size vector type.
        // Without this guard, Skip(1) is empty, All(...) is vacuously true, and an empty dimension
        // array would be handed to the sized VectorType constructor.
        int rank = Utils.Size(shapeArray);
        bool isSingleUnknownDim = rank == 1 && shapeArray[0] <= 0;

        // A sized vector is only produced when every dimension after the first is known. An unknown
        // leading dimension is dropped from the shape; a known one is kept.
        if (rank > 0 && !isSingleUnknownDim && shapeArray.Skip(1).All(x => x > 0))
            columnType = new VectorType(mlType, shapeArray[0] > 0 ? shapeArray : shapeArray.Skip(1).ToArray());

        // There can be at most two metadata fields.
        //  1. The first field is always present. Its value is the type of the operation that produces this
        //     column (op.OpType), stored under the OpType metadata kind.
        //  2. The second field stores the names of the operations whose outputs are consumed by this operation,
        //     i.e. the upstream operations that must be evaluated before the current one. An operation may
        //     have no inputs, in which case this field is missing.
        var metadataBuilder = new MetadataBuilder();
        metadataBuilder.Add(OpType, TextType.Instance, (ref ReadOnlyMemory<char> value) => value = op.OpType.AsMemory());
        if (op.NumInputs > 0)
            metadataBuilder.Add(InputOps, new VectorType(TextType.Instance, op.NumInputs),
                (ref VBuffer<ReadOnlyMemory<char>> value) =>
                {
                    // The input names are read from the operation each time the getter is invoked.
                    var bufferEditor = VBufferEditor.Create(ref value, op.NumInputs);
                    for (int i = 0; i < op.NumInputs; ++i)
                        bufferEditor.Values[i] = op.GetInput(i).Operation.Name.AsMemory();
                    value = bufferEditor.Commit();
                });

        schemaBuilder.AddColumn(op.Name, columnType, metadataBuilder.GetMetadata());
    }
    return schemaBuilder.GetSchema();
}
69
67
70
68
/// <summary>
@@ -75,25 +73,27 @@ internal static Schema GetModelSchema(IExceptionContext ectx, TFGraph graph, str
75
73
/// of kind <see cref="OpType"/>, indicating the operation type of the node, and if that node has inputs in the graph,
76
74
/// it contains metadata of kind <see cref="InputOps"/>, indicating the names of the input nodes.
77
75
/// </summary>
78
/// <param name="ectx">The <see cref="IExceptionContext"/> forwarded to the session loader and to the graph-based overload.</param>
/// <param name="modelFile">Path of the file that contains the TensorFlow model. Currently only frozen model
/// format is supported.</param>
public static Schema GetModelSchema(IExceptionContext ectx, string modelFile)
    // Read the serialized model from disk, load it into a session, then describe that session's graph.
    => GetModelSchema(ectx, LoadTFSession(ectx, File.ReadAllBytes(modelFile), modelFile).Graph);
85
85
86
86
/// <summary>
87
87
/// This is a convenience method for iterating over the nodes of a TensorFlow model graph. It
88
- /// iterates over the columns of the <see cref="ISchema"/> returned by <see cref="GetModelSchema(IHostEnvironment , string)"/>,
88
+ /// iterates over the columns of the <see cref="ISchema"/> returned by <see cref="GetModelSchema(IExceptionContext , string)"/>,
89
89
/// and for each one it returns a tuple containing the name, operation type, column type and an array of input node names.
90
90
/// This method is convenient for filtering nodes based on certain criteria, for example, by the operation type.
91
91
/// </summary>
92
- /// <param name="modelPath">Model to load. </param>
92
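+ /// <param name="modelFile">The name of the file containing the TensorFlow model. Currently only frozen model format is supported.</param>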
93
93
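/// <returns>One tuple per schema column, containing the node name, its operation type, its column type and an array of its input node names.</returns>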
94
- public static IEnumerable < ( string , string , ColumnType , string [ ] ) > GetModelNodes ( string modelPath )
94
+ public static IEnumerable < ( string , string , ColumnType , string [ ] ) > GetModelNodes ( string modelFile )
95
95
{
96
- var schema = GetModelSchema ( new MLContext ( ) , modelPath ) ;
96
+ var schema = GetModelSchema ( null , modelFile ) ;
97
97
98
98
for ( int i = 0 ; i < schema . Count ; i ++ )
99
99
{
@@ -308,12 +308,6 @@ private static void CreateTempDirectoryWithAcl(string folder, string identity)
308
308
}
309
309
}
310
310
311
- /// <summary>
312
- /// Load TensorFlow model into memory.
313
- /// </summary>
314
- /// <param name="env">The environment to use.</param>
315
- /// <param name="modelPath">The model to load.</param>
316
- /// <returns></returns>
317
311
public static TensorFlowModelInfo LoadTensorFlowModel ( IHostEnvironment env , string modelPath )
318
312
{
319
313
var session = GetSession ( env , modelPath ) ;
@@ -358,55 +352,5 @@ internal static bool IsTypeSupported(TFDataType tfoutput)
358
352
return false ;
359
353
}
360
354
}
361
-
362
- private sealed class TensorFlowSchema : SimpleSchemaBase
363
- {
364
- private readonly MetadataUtils . MetadataGetter < ReadOnlyMemory < char > > [ ] _opTypeGetters ;
365
- private readonly MetadataUtils . MetadataGetter < VBuffer < ReadOnlyMemory < char > > > [ ] _inputOpsGetters ;
366
- private readonly int [ ] _inputOpsLengths ;
367
-
368
- public TensorFlowSchema ( IExceptionContext ectx , KeyValuePair < string , ColumnType > [ ] columns ,
369
- MetadataUtils . MetadataGetter < ReadOnlyMemory < char > > [ ] opTypeGetters ,
370
- MetadataUtils . MetadataGetter < VBuffer < ReadOnlyMemory < char > > > [ ] inputOpsGetters , int [ ] inputOpsLengths )
371
- : base ( ectx , columns )
372
- {
373
- ectx . CheckParam ( Utils . Size ( opTypeGetters ) == ColumnCount , nameof ( opTypeGetters ) ) ;
374
- ectx . CheckParam ( Utils . Size ( inputOpsGetters ) == ColumnCount , nameof ( inputOpsGetters ) ) ;
375
- ectx . CheckParam ( Utils . Size ( inputOpsLengths ) == ColumnCount , nameof ( inputOpsLengths ) ) ;
376
-
377
- _opTypeGetters = opTypeGetters ;
378
- _inputOpsGetters = inputOpsGetters ;
379
- _inputOpsLengths = inputOpsLengths ;
380
- }
381
-
382
- protected override void GetMetadataCore < TValue > ( string kind , int col , ref TValue value )
383
- {
384
- Ectx . Assert ( 0 <= col && col < ColumnCount ) ;
385
- if ( kind == OpType )
386
- _opTypeGetters [ col ] . Marshal ( col , ref value ) ;
387
- else if ( kind == InputOps && _inputOpsGetters [ col ] != null )
388
- _inputOpsGetters [ col ] . Marshal ( col , ref value ) ;
389
- else
390
- throw Ectx . ExceptGetMetadata ( ) ;
391
- }
392
-
393
- protected override ColumnType GetMetadataTypeOrNullCore ( string kind , int col )
394
- {
395
- Ectx . Assert ( 0 <= col && col < ColumnCount ) ;
396
- if ( kind == OpType )
397
- return TextType . Instance ;
398
- if ( kind == InputOps && _inputOpsGetters [ col ] != null )
399
- return new VectorType ( TextType . Instance , _inputOpsLengths [ col ] ) ;
400
- return null ;
401
- }
402
-
403
- protected override IEnumerable < KeyValuePair < string , ColumnType > > GetMetadataTypesCore ( int col )
404
- {
405
- Ectx . Assert ( 0 <= col && col < ColumnCount ) ;
406
- yield return new KeyValuePair < string , ColumnType > ( OpType , TextType . Instance ) ;
407
- if ( _inputOpsGetters [ col ] != null )
408
- yield return new KeyValuePair < string , ColumnType > ( InputOps , new VectorType ( TextType . Instance , _inputOpsLengths [ col ] ) ) ;
409
- }
410
- }
411
355
}
412
356
}
0 commit comments