Skip to content

Commit 4bfd7a1

Browse files
authored
Remove ISchema in TensorflowUtils.cs and polish its shape translation (dotnet#2123)
* Remove ISchema in TensorflowUtils.cs and polish its shape translation
* Remove another ISchema because it's only used in Tensorflow
* Rename Tensorflow metadata fields
* Minor changes
  1. Use VBuffer.CopyTo instead of "="
  2. Rename a variable
1 parent 5f9abe3 commit 4bfd7a1

File tree

3 files changed

+66
-178
lines changed

3 files changed

+66
-178
lines changed

src/Microsoft.ML.Data/DataView/SimpleRow.cs

Lines changed: 0 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -69,80 +69,6 @@ public override bool IsColumnActive(int col)
6969
}
7070
}
7171

72-
/// <summary>
73-
/// An <see cref="ISchema"/> that takes all column names and types as constructor parameters.
74-
/// The columns do not have metadata.
75-
/// </summary>
76-
public abstract class SimpleSchemaBase : ISchema
77-
{
78-
protected readonly IExceptionContext Ectx;
79-
private readonly string[] _names;
80-
protected readonly ColumnType[] Types;
81-
protected readonly Dictionary<string, int> ColumnNameMap;
82-
83-
public int ColumnCount => Types.Length;
84-
85-
protected SimpleSchemaBase(IExceptionContext ectx, params KeyValuePair<string, ColumnType>[] columns)
86-
{
87-
Contracts.CheckValueOrNull(ectx);
88-
Ectx = ectx;
89-
Ectx.CheckValue(columns, nameof(columns));
90-
91-
_names = new string[columns.Length];
92-
Types = new ColumnType[columns.Length];
93-
ColumnNameMap = new Dictionary<string, int>();
94-
for (int i = 0; i < columns.Length; i++)
95-
{
96-
_names[i] = columns[i].Key;
97-
Types[i] = columns[i].Value;
98-
if (ColumnNameMap.ContainsKey(columns[i].Key))
99-
throw ectx.ExceptParam(nameof(columns), $"Duplicate column name: '{columns[i].Key}'");
100-
ColumnNameMap[columns[i].Key] = i;
101-
}
102-
}
103-
104-
public bool TryGetColumnIndex(string name, out int col)
105-
{
106-
return ColumnNameMap.TryGetValue(name, out col);
107-
}
108-
109-
public string GetColumnName(int col)
110-
{
111-
Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
112-
return _names[col];
113-
}
114-
115-
public ColumnType GetColumnType(int col)
116-
{
117-
Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
118-
return Types[col];
119-
}
120-
121-
public IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypes(int col)
122-
{
123-
Ectx.Assert(0 <= col && col < ColumnCount);
124-
return GetMetadataTypesCore(col);
125-
}
126-
127-
protected abstract IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypesCore(int col);
128-
129-
public ColumnType GetMetadataTypeOrNull(string kind, int col)
130-
{
131-
Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
132-
return GetMetadataTypeOrNullCore(kind, col);
133-
}
134-
135-
protected abstract ColumnType GetMetadataTypeOrNullCore(string kind, int col);
136-
137-
public void GetMetadata<TValue>(string kind, int col, ref TValue value)
138-
{
139-
Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
140-
GetMetadataCore(kind, col, ref value);
141-
}
142-
143-
protected abstract void GetMetadataCore<TValue>(string kind, int col, ref TValue value);
144-
}
145-
14672
public static class SimpleSchemaUtils
14773
{
14874
public static Schema Create(IExceptionContext ectx, params KeyValuePair<string, ColumnType>[] columns)

src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs

Lines changed: 51 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -15,65 +15,77 @@ namespace Microsoft.ML.Transforms.TensorFlow
1515
{
1616
public static class TensorFlowUtils
1717
{
18-
public const string OpType = "OpType";
19-
public const string InputOps = "InputOps";
18+
/// <summary>
19+
/// Key to access operator's type (a string) in <see cref="Schema.Column.Metadata"/>.
20+
/// Its value describes the Tensorflow operator that produces this <see cref="Schema.Column"/>.
21+
/// </summary>
22+
public const string TensorflowOperatorTypeKind = "TensorflowOperatorType";
23+
/// <summary>
24+
/// Key to access upstream operators' names (a string array) in <see cref="Schema.Column.Metadata"/>.
25+
/// Its value lists the operators that the operator producing the associated <see cref="Schema.Column"/> depends on.
26+
/// </summary>
27+
public const string TensorflowUpstreamOperatorsKind = "TensorflowUpstreamOperators";
2028

2129
internal static Schema GetModelSchema(IExceptionContext ectx, TFGraph graph, string opType = null)
2230
{
23-
var res = new List<KeyValuePair<string, ColumnType>>();
24-
var opTypeGetters = new List<MetadataUtils.MetadataGetter<ReadOnlyMemory<char>>>();
25-
var inputOpsGetters = new List<MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>>();
26-
var inputOpsLengths = new List<int>();
31+
var schemaBuilder = new SchemaBuilder();
2732
foreach (var op in graph)
2833
{
2934
if (opType != null && opType != op.OpType)
3035
continue;
36+
3137
var tfType = op[0].OutputType;
38+
// Determine element type in Tensorflow tensor. For example, a vector of floats may get NumberType.R4 here.
3239
var mlType = Tf2MlNetTypeOrNull(tfType);
3340

34-
// If the type is not supported in ML.NET then we cannot represent it as a column in an ISchema.
41+
// If the type is not supported in ML.NET then we cannot represent it as a column in a Schema.
3542
// We also cannot output it with a TensorFlowTransform, so we skip it.
3643
if (mlType == null)
3744
continue;
3845

39-
var shape = graph.GetTensorShape(op[0]);
40-
var shapeArray = shape.ToIntArray();
41-
42-
inputOpsLengths.Add(op.NumInputs);
43-
MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>> inputOpsGetter = null;
46+
// Construct the final ML.NET type of a Tensorflow variable.
47+
var tensorShape = graph.GetTensorShape(op[0]).ToIntArray();
48+
var columnType = new VectorType(mlType);
49+
if (!(Utils.Size(tensorShape) == 1 && tensorShape[0] <= 0) &&
50+
(Utils.Size(tensorShape) > 0 && tensorShape.Skip(1).All(x => x > 0)))
51+
columnType = new VectorType(mlType, tensorShape[0] > 0 ? tensorShape : tensorShape.Skip(1).ToArray());
52+
53+
// There can be at most two metadata fields.
54+
// 1. The first field is always present. Its value is this operator's type. For example,
55+
// if an output is produced by a "Softmax" operator, the value of this field should be "Softmax".
56+
// 2. The second field stores operators whose outputs are consumed by this operator. In other words,
57+
// these values are names of some upstream operators which should be evaluated before executing
58+
// the current operator. It's possible that one operator doesn't need any input, so this field
59+
// can be missing.
60+
var metadataBuilder = new MetadataBuilder();
61+
// Create the first metadata field.
62+
metadataBuilder.Add(TensorflowOperatorTypeKind, TextType.Instance, (ref ReadOnlyMemory<char> value) => value = op.OpType.AsMemory());
4463
if (op.NumInputs > 0)
4564
{
46-
var inputOps = new ReadOnlyMemory<char>[op.NumInputs];
47-
for (int i = 0; i < op.NumInputs; i++)
48-
{
49-
var input = op.GetInput(i);
50-
inputOps[i] = new ReadOnlyMemory<char>(input.Operation.Name.ToArray());
51-
}
52-
inputOpsGetter = (int col, ref VBuffer<ReadOnlyMemory<char>> dst) =>
53-
dst = new VBuffer<ReadOnlyMemory<char>>(op.NumInputs, inputOps);
65+
// Put upstream operators' names into an array (type: VBuffer) of strings (type: ReadOnlyMemory<char>).
66+
VBuffer<ReadOnlyMemory<char>> upstreamOperatorNames = default;
67+
var bufferEditor = VBufferEditor.Create(ref upstreamOperatorNames, op.NumInputs);
68+
for (int i = 0; i < op.NumInputs; ++i)
69+
bufferEditor.Values[i] = op.GetInput(i).Operation.Name.AsMemory();
70+
upstreamOperatorNames = bufferEditor.Commit(); // Used in metadata's getter.
71+
72+
// Create the second metadata field.
73+
metadataBuilder.Add(TensorflowUpstreamOperatorsKind, new VectorType(TextType.Instance, op.NumInputs),
74+
(ref VBuffer<ReadOnlyMemory<char>> value) => { upstreamOperatorNames.CopyTo(ref value); });
5475
}
55-
inputOpsGetters.Add(inputOpsGetter);
56-
57-
MetadataUtils.MetadataGetter<ReadOnlyMemory<char>> opTypeGetter =
58-
(int col, ref ReadOnlyMemory<char> dst) => dst = new ReadOnlyMemory<char>(op.OpType.ToArray());
59-
opTypeGetters.Add(opTypeGetter);
6076

61-
var columnType = Utils.Size(shapeArray) == 1 && shapeArray[0] <= 0 ? new VectorType(mlType) :
62-
Utils.Size(shapeArray) > 0 && shapeArray.Skip(1).All(x => x > 0) ?
63-
new VectorType(mlType, shapeArray[0] > 0 ? shapeArray : shapeArray.Skip(1).ToArray())
64-
: new VectorType(mlType);
65-
res.Add(new KeyValuePair<string, ColumnType>(op.Name, columnType));
77+
schemaBuilder.AddColumn(op.Name, columnType, metadataBuilder.GetMetadata());
6678
}
67-
return Schema.Create(new TensorFlowSchema(ectx, res.ToArray(), opTypeGetters.ToArray(), inputOpsGetters.ToArray(), inputOpsLengths.ToArray()));
79+
return schemaBuilder.GetSchema();
6880
}
6981

7082
/// <summary>
71-
/// This method retrieves the information about the graph nodes of a TensorFlow model as an <see cref="ISchema"/>.
83+
/// This method retrieves the information about the graph nodes of a TensorFlow model as an <see cref="Schema"/>.
7284
/// For every node in the graph that has an output type that is compatible with the types supported by
7385
/// <see cref="TensorFlowTransformer"/>, the output schema contains a column with the name of that node, and the
7486
/// type of its output (including the item type and the shape, if it is known). Every column also contains metadata
75-
/// of kind <see cref="OpType"/>, indicating the operation type of the node, and if that node has inputs in the graph,
76-
/// it contains metadata of kind <see cref="InputOps"/>, indicating the names of the input nodes.
87+
/// of kind <see cref="TensorflowOperatorTypeKind"/>, indicating the operation type of the node, and if that node has inputs in the graph,
88+
/// it contains metadata of kind <see cref="TensorflowUpstreamOperatorsKind"/>, indicating the names of the input nodes.
7789
/// </summary>
7890
/// <param name="env">The environment to use.</param>
7991
/// <param name="modelPath">Model to load.</param>
@@ -85,7 +97,7 @@ public static Schema GetModelSchema(IHostEnvironment env, string modelPath)
8597

8698
/// <summary>
8799
/// This is a convenience method for iterating over the nodes of a TensorFlow model graph. It
88-
/// iterates over the columns of the <see cref="ISchema"/> returned by <see cref="GetModelSchema(IHostEnvironment, string)"/>,
100+
/// iterates over the columns of the <see cref="Schema"/> returned by <see cref="GetModelSchema(IHostEnvironment, string)"/>,
89101
/// and for each one it returns a tuple containing the name, operation type, column type and an array of input node names.
90102
/// This method is convenient for filtering nodes based on certain criteria, for example, by the operation type.
91103
/// </summary>
@@ -100,16 +112,16 @@ public static Schema GetModelSchema(IHostEnvironment env, string modelPath)
100112
var name = schema[i].Name;
101113
var type = schema[i].Type;
102114

103-
var metadataType = schema[i].Metadata.Schema.GetColumnOrNull(TensorFlowUtils.OpType)?.Type;
115+
var metadataType = schema[i].Metadata.Schema.GetColumnOrNull(TensorflowOperatorTypeKind)?.Type;
104116
Contracts.Assert(metadataType != null && metadataType is TextType);
105117
ReadOnlyMemory<char> opType = default;
106-
schema[i].Metadata.GetValue(TensorFlowUtils.OpType, ref opType);
107-
metadataType = schema[i].Metadata.Schema.GetColumnOrNull(TensorFlowUtils.InputOps)?.Type;
118+
schema[i].Metadata.GetValue(TensorflowOperatorTypeKind, ref opType);
119+
metadataType = schema[i].Metadata.Schema.GetColumnOrNull(TensorflowUpstreamOperatorsKind)?.Type;
108120
VBuffer<ReadOnlyMemory<char>> inputOps = default;
109121
if (metadataType != null)
110122
{
111123
Contracts.Assert(metadataType.IsKnownSizeVector && metadataType.ItemType is TextType);
112-
schema[i].Metadata.GetValue(TensorFlowUtils.InputOps, ref inputOps);
124+
schema[i].Metadata.GetValue(TensorflowUpstreamOperatorsKind, ref inputOps);
113125
}
114126

115127
string[] inputOpsResult = inputOps.DenseValues()
@@ -358,55 +370,5 @@ internal static bool IsTypeSupported(TFDataType tfoutput)
358370
return false;
359371
}
360372
}
361-
362-
private sealed class TensorFlowSchema : SimpleSchemaBase
363-
{
364-
private readonly MetadataUtils.MetadataGetter<ReadOnlyMemory<char>>[] _opTypeGetters;
365-
private readonly MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>[] _inputOpsGetters;
366-
private readonly int[] _inputOpsLengths;
367-
368-
public TensorFlowSchema(IExceptionContext ectx, KeyValuePair<string, ColumnType>[] columns,
369-
MetadataUtils.MetadataGetter<ReadOnlyMemory<char>>[] opTypeGetters,
370-
MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>[] inputOpsGetters, int[] inputOpsLengths)
371-
: base(ectx, columns)
372-
{
373-
ectx.CheckParam(Utils.Size(opTypeGetters) == ColumnCount, nameof(opTypeGetters));
374-
ectx.CheckParam(Utils.Size(inputOpsGetters) == ColumnCount, nameof(inputOpsGetters));
375-
ectx.CheckParam(Utils.Size(inputOpsLengths) == ColumnCount, nameof(inputOpsLengths));
376-
377-
_opTypeGetters = opTypeGetters;
378-
_inputOpsGetters = inputOpsGetters;
379-
_inputOpsLengths = inputOpsLengths;
380-
}
381-
382-
protected override void GetMetadataCore<TValue>(string kind, int col, ref TValue value)
383-
{
384-
Ectx.Assert(0 <= col && col < ColumnCount);
385-
if (kind == OpType)
386-
_opTypeGetters[col].Marshal(col, ref value);
387-
else if (kind == InputOps && _inputOpsGetters[col] != null)
388-
_inputOpsGetters[col].Marshal(col, ref value);
389-
else
390-
throw Ectx.ExceptGetMetadata();
391-
}
392-
393-
protected override ColumnType GetMetadataTypeOrNullCore(string kind, int col)
394-
{
395-
Ectx.Assert(0 <= col && col < ColumnCount);
396-
if (kind == OpType)
397-
return TextType.Instance;
398-
if (kind == InputOps && _inputOpsGetters[col] != null)
399-
return new VectorType(TextType.Instance, _inputOpsLengths[col]);
400-
return null;
401-
}
402-
403-
protected override IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypesCore(int col)
404-
{
405-
Ectx.Assert(0 <= col && col < ColumnCount);
406-
yield return new KeyValuePair<string, ColumnType>(OpType, TextType.Instance);
407-
if (_inputOpsGetters[col] != null)
408-
yield return new KeyValuePair<string, ColumnType>(InputOps, new VectorType(TextType.Instance, _inputOpsLengths[col]));
409-
}
410-
}
411373
}
412374
}

0 commit comments

Comments
 (0)