@@ -15,65 +15,77 @@ namespace Microsoft.ML.Transforms.TensorFlow
15
15
{
16
16
public static class TensorFlowUtils
17
17
{
18
- public const string OpType = "OpType" ;
19
- public const string InputOps = "InputOps" ;
18
+ /// <summary>
19
+ /// Key to access operator's type (a string) in <see cref="Schema.Column.Metadata"/>.
20
+ /// Its value describes the Tensorflow operator that produces this <see cref="Schema.Column"/>.
21
+ /// </summary>
22
+ public const string TensorflowOperatorTypeKind = "TensorflowOperatorType" ;
23
+ /// <summary>
24
+ /// Key to access upstream operators' names (a string array) in <see cref="Schema.Column.Metadata"/>.
25
+ /// Its value states operators that the associated <see cref="Schema.Column"/>'s generator depends on.
26
+ /// </summary>
27
+ public const string TensorflowUpstreamOperatorsKind = "TensorflowUpstreamOperators" ;
20
28
21
29
internal static Schema GetModelSchema ( IExceptionContext ectx , TFGraph graph , string opType = null )
22
30
{
23
- var res = new List < KeyValuePair < string , ColumnType > > ( ) ;
24
- var opTypeGetters = new List < MetadataUtils . MetadataGetter < ReadOnlyMemory < char > > > ( ) ;
25
- var inputOpsGetters = new List < MetadataUtils . MetadataGetter < VBuffer < ReadOnlyMemory < char > > > > ( ) ;
26
- var inputOpsLengths = new List < int > ( ) ;
31
+ var schemaBuilder = new SchemaBuilder ( ) ;
27
32
foreach ( var op in graph )
28
33
{
29
34
if ( opType != null && opType != op . OpType )
30
35
continue ;
36
+
31
37
var tfType = op [ 0 ] . OutputType ;
38
+ // Determine element type in Tensorflow tensor. For example, a vector of floats may get NumberType.R4 here.
32
39
var mlType = Tf2MlNetTypeOrNull ( tfType ) ;
33
40
34
- // If the type is not supported in ML.NET then we cannot represent it as a column in an ISchema .
41
+ // If the type is not supported in ML.NET then we cannot represent it as a column in an Schema .
35
42
// We also cannot output it with a TensorFlowTransform, so we skip it.
36
43
if ( mlType == null )
37
44
continue ;
38
45
39
- var shape = graph . GetTensorShape ( op [ 0 ] ) ;
40
- var shapeArray = shape . ToIntArray ( ) ;
41
-
42
- inputOpsLengths . Add ( op . NumInputs ) ;
43
- MetadataUtils . MetadataGetter < VBuffer < ReadOnlyMemory < char > > > inputOpsGetter = null ;
46
+ // Construct the final ML.NET type of a Tensorflow variable.
47
+ var tensorShape = graph . GetTensorShape ( op [ 0 ] ) . ToIntArray ( ) ;
48
+ var columnType = new VectorType ( mlType ) ;
49
+ if ( ! ( Utils . Size ( tensorShape ) == 1 && tensorShape [ 0 ] <= 0 ) &&
50
+ ( Utils . Size ( tensorShape ) > 0 && tensorShape . Skip ( 1 ) . All ( x => x > 0 ) ) )
51
+ columnType = new VectorType ( mlType , tensorShape [ 0 ] > 0 ? tensorShape : tensorShape . Skip ( 1 ) . ToArray ( ) ) ;
52
+
53
+ // There can be at most two metadata fields.
54
+ // 1. The first field always presents. Its value is this operator's type. For example,
55
+ // if an output is produced by an "Softmax" operator, the value of this field should be "Softmax".
56
+ // 2. The second field stores operators whose outputs are consumed by this operator. In other words,
57
+ // these values are names of some upstream operators which should be evaluated before executing
58
+ // the current operator. It's possible that one operator doesn't need any input, so this field
59
+ // can be missing.
60
+ var metadataBuilder = new MetadataBuilder ( ) ;
61
+ // Create the first metadata field.
62
+ metadataBuilder . Add ( TensorflowOperatorTypeKind , TextType . Instance , ( ref ReadOnlyMemory < char > value ) => value = op . OpType . AsMemory ( ) ) ;
44
63
if ( op . NumInputs > 0 )
45
64
{
46
- var inputOps = new ReadOnlyMemory < char > [ op . NumInputs ] ;
47
- for ( int i = 0 ; i < op . NumInputs ; i ++ )
48
- {
49
- var input = op . GetInput ( i ) ;
50
- inputOps [ i ] = new ReadOnlyMemory < char > ( input . Operation . Name . ToArray ( ) ) ;
51
- }
52
- inputOpsGetter = ( int col , ref VBuffer < ReadOnlyMemory < char > > dst ) =>
53
- dst = new VBuffer < ReadOnlyMemory < char > > ( op . NumInputs , inputOps ) ;
65
+ // Put upstream operators' names to an array (type: VBuffer) of string (type: ReadOnlyMemory<char>).
66
+ VBuffer < ReadOnlyMemory < char > > upstreamOperatorNames = default ;
67
+ var bufferEditor = VBufferEditor . Create ( ref upstreamOperatorNames , op . NumInputs ) ;
68
+ for ( int i = 0 ; i < op . NumInputs ; ++ i )
69
+ bufferEditor . Values [ i ] = op . GetInput ( i ) . Operation . Name . AsMemory ( ) ;
70
+ upstreamOperatorNames = bufferEditor . Commit ( ) ; // Used in metadata's getter.
71
+
72
+ // Create the second metadata field.
73
+ metadataBuilder . Add ( TensorflowUpstreamOperatorsKind , new VectorType ( TextType . Instance , op . NumInputs ) ,
74
+ ( ref VBuffer < ReadOnlyMemory < char > > value ) => { upstreamOperatorNames . CopyTo ( ref value ) ; } ) ;
54
75
}
55
- inputOpsGetters . Add ( inputOpsGetter ) ;
56
-
57
- MetadataUtils . MetadataGetter < ReadOnlyMemory < char > > opTypeGetter =
58
- ( int col , ref ReadOnlyMemory < char > dst ) => dst = new ReadOnlyMemory < char > ( op . OpType . ToArray ( ) ) ;
59
- opTypeGetters . Add ( opTypeGetter ) ;
60
76
61
- var columnType = Utils . Size ( shapeArray ) == 1 && shapeArray [ 0 ] <= 0 ? new VectorType ( mlType ) :
62
- Utils . Size ( shapeArray ) > 0 && shapeArray . Skip ( 1 ) . All ( x => x > 0 ) ?
63
- new VectorType ( mlType , shapeArray [ 0 ] > 0 ? shapeArray : shapeArray . Skip ( 1 ) . ToArray ( ) )
64
- : new VectorType ( mlType ) ;
65
- res . Add ( new KeyValuePair < string , ColumnType > ( op . Name , columnType ) ) ;
77
+ schemaBuilder . AddColumn ( op . Name , columnType , metadataBuilder . GetMetadata ( ) ) ;
66
78
}
67
- return Schema . Create ( new TensorFlowSchema ( ectx , res . ToArray ( ) , opTypeGetters . ToArray ( ) , inputOpsGetters . ToArray ( ) , inputOpsLengths . ToArray ( ) ) ) ;
79
+ return schemaBuilder . GetSchema ( ) ;
68
80
}
69
81
70
82
/// <summary>
71
- /// This method retrieves the information about the graph nodes of a TensorFlow model as an <see cref="ISchema "/>.
83
+ /// This method retrieves the information about the graph nodes of a TensorFlow model as an <see cref="Schema "/>.
72
84
/// For every node in the graph that has an output type that is compatible with the types supported by
73
85
/// <see cref="TensorFlowTransformer"/>, the output schema contains a column with the name of that node, and the
74
86
/// type of its output (including the item type and the shape, if it is known). Every column also contains metadata
75
- /// of kind <see cref="OpType "/>, indicating the operation type of the node, and if that node has inputs in the graph,
76
- /// it contains metadata of kind <see cref="InputOps "/>, indicating the names of the input nodes.
87
+ /// of kind <see cref="TensorflowOperatorTypeKind "/>, indicating the operation type of the node, and if that node has inputs in the graph,
88
+ /// it contains metadata of kind <see cref="TensorflowUpstreamOperatorsKind "/>, indicating the names of the input nodes.
77
89
/// </summary>
78
90
/// <param name="env">The environment to use.</param>
79
91
/// <param name="modelPath">Model to load.</param>
@@ -85,7 +97,7 @@ public static Schema GetModelSchema(IHostEnvironment env, string modelPath)
85
97
86
98
/// <summary>
87
99
/// This is a convenience method for iterating over the nodes of a TensorFlow model graph. It
88
- /// iterates over the columns of the <see cref="ISchema "/> returned by <see cref="GetModelSchema(IHostEnvironment, string)"/>,
100
+ /// iterates over the columns of the <see cref="Schema "/> returned by <see cref="GetModelSchema(IHostEnvironment, string)"/>,
89
101
/// and for each one it returns a tuple containing the name, operation type, column type and an array of input node names.
90
102
/// This method is convenient for filtering nodes based on certain criteria, for example, by the operation type.
91
103
/// </summary>
@@ -100,16 +112,16 @@ public static Schema GetModelSchema(IHostEnvironment env, string modelPath)
100
112
var name = schema [ i ] . Name ;
101
113
var type = schema [ i ] . Type ;
102
114
103
- var metadataType = schema [ i ] . Metadata . Schema . GetColumnOrNull ( TensorFlowUtils . OpType ) ? . Type ;
115
+ var metadataType = schema [ i ] . Metadata . Schema . GetColumnOrNull ( TensorflowOperatorTypeKind ) ? . Type ;
104
116
Contracts . Assert ( metadataType != null && metadataType is TextType ) ;
105
117
ReadOnlyMemory < char > opType = default ;
106
- schema [ i ] . Metadata . GetValue ( TensorFlowUtils . OpType , ref opType ) ;
107
- metadataType = schema [ i ] . Metadata . Schema . GetColumnOrNull ( TensorFlowUtils . InputOps ) ? . Type ;
118
+ schema [ i ] . Metadata . GetValue ( TensorflowOperatorTypeKind , ref opType ) ;
119
+ metadataType = schema [ i ] . Metadata . Schema . GetColumnOrNull ( TensorflowUpstreamOperatorsKind ) ? . Type ;
108
120
VBuffer < ReadOnlyMemory < char > > inputOps = default ;
109
121
if ( metadataType != null )
110
122
{
111
123
Contracts . Assert ( metadataType . IsKnownSizeVector && metadataType . ItemType is TextType ) ;
112
- schema [ i ] . Metadata . GetValue ( TensorFlowUtils . InputOps , ref inputOps ) ;
124
+ schema [ i ] . Metadata . GetValue ( TensorflowUpstreamOperatorsKind , ref inputOps ) ;
113
125
}
114
126
115
127
string [ ] inputOpsResult = inputOps . DenseValues ( )
@@ -358,55 +370,5 @@ internal static bool IsTypeSupported(TFDataType tfoutput)
358
370
return false ;
359
371
}
360
372
}
361
-
362
- private sealed class TensorFlowSchema : SimpleSchemaBase
363
- {
364
- private readonly MetadataUtils . MetadataGetter < ReadOnlyMemory < char > > [ ] _opTypeGetters ;
365
- private readonly MetadataUtils . MetadataGetter < VBuffer < ReadOnlyMemory < char > > > [ ] _inputOpsGetters ;
366
- private readonly int [ ] _inputOpsLengths ;
367
-
368
- public TensorFlowSchema ( IExceptionContext ectx , KeyValuePair < string , ColumnType > [ ] columns ,
369
- MetadataUtils . MetadataGetter < ReadOnlyMemory < char > > [ ] opTypeGetters ,
370
- MetadataUtils . MetadataGetter < VBuffer < ReadOnlyMemory < char > > > [ ] inputOpsGetters , int [ ] inputOpsLengths )
371
- : base ( ectx , columns )
372
- {
373
- ectx . CheckParam ( Utils . Size ( opTypeGetters ) == ColumnCount , nameof ( opTypeGetters ) ) ;
374
- ectx . CheckParam ( Utils . Size ( inputOpsGetters ) == ColumnCount , nameof ( inputOpsGetters ) ) ;
375
- ectx . CheckParam ( Utils . Size ( inputOpsLengths ) == ColumnCount , nameof ( inputOpsLengths ) ) ;
376
-
377
- _opTypeGetters = opTypeGetters ;
378
- _inputOpsGetters = inputOpsGetters ;
379
- _inputOpsLengths = inputOpsLengths ;
380
- }
381
-
382
- protected override void GetMetadataCore < TValue > ( string kind , int col , ref TValue value )
383
- {
384
- Ectx . Assert ( 0 <= col && col < ColumnCount ) ;
385
- if ( kind == OpType )
386
- _opTypeGetters [ col ] . Marshal ( col , ref value ) ;
387
- else if ( kind == InputOps && _inputOpsGetters [ col ] != null )
388
- _inputOpsGetters [ col ] . Marshal ( col , ref value ) ;
389
- else
390
- throw Ectx . ExceptGetMetadata ( ) ;
391
- }
392
-
393
- protected override ColumnType GetMetadataTypeOrNullCore ( string kind , int col )
394
- {
395
- Ectx . Assert ( 0 <= col && col < ColumnCount ) ;
396
- if ( kind == OpType )
397
- return TextType . Instance ;
398
- if ( kind == InputOps && _inputOpsGetters [ col ] != null )
399
- return new VectorType ( TextType . Instance , _inputOpsLengths [ col ] ) ;
400
- return null ;
401
- }
402
-
403
- protected override IEnumerable < KeyValuePair < string , ColumnType > > GetMetadataTypesCore ( int col )
404
- {
405
- Ectx . Assert ( 0 <= col && col < ColumnCount ) ;
406
- yield return new KeyValuePair < string , ColumnType > ( OpType , TextType . Instance ) ;
407
- if ( _inputOpsGetters [ col ] != null )
408
- yield return new KeyValuePair < string , ColumnType > ( InputOps , new VectorType ( TextType . Instance , _inputOpsLengths [ col ] ) ) ;
409
- }
410
- }
411
373
}
412
374
}
0 commit comments