Skip to content

Commit e72f3f8

Browse files
committed
Remove ISchema in TensorflowUtils.cs and polish its shape translation
1 parent 447e94e commit e72f3f8

File tree

1 file changed

+40
-96
lines changed

1 file changed

+40
-96
lines changed

src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs

+40-96
Original file line numberDiff line numberDiff line change
@@ -20,51 +20,49 @@ public static class TensorFlowUtils
2020

2121
internal static Schema GetModelSchema(IExceptionContext ectx, TFGraph graph, string opType = null)
2222
{
23-
var res = new List<KeyValuePair<string, ColumnType>>();
24-
var opTypeGetters = new List<MetadataUtils.MetadataGetter<ReadOnlyMemory<char>>>();
25-
var inputOpsGetters = new List<MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>>();
26-
var inputOpsLengths = new List<int>();
23+
var schemaBuilder = new SchemaBuilder();
2724
foreach (var op in graph)
2825
{
2926
if (opType != null && opType != op.OpType)
3027
continue;
28+
3129
var tfType = op[0].OutputType;
30+
// Determine element type in Tensorflow tensor. For example, a vector of floats may get NumberType.R4 here.
3231
var mlType = Tf2MlNetTypeOrNull(tfType);
3332

3433
// If the type is not supported in ML.NET then we cannot represent it as a column in a Schema.
3534
// We also cannot output it with a TensorFlowTransform, so we skip it.
3635
if (mlType == null)
3736
continue;
3837

39-
var shape = graph.GetTensorShape(op[0]);
40-
var shapeArray = shape.ToIntArray();
41-
42-
inputOpsLengths.Add(op.NumInputs);
43-
MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>> inputOpsGetter = null;
38+
// Construct the final ML.NET type of a Tensorflow variable.
39+
var shapeArray = graph.GetTensorShape(op[0]).ToIntArray();
40+
var columnType = new VectorType(mlType);
41+
if (Utils.Size(shapeArray) > 0 && shapeArray.Skip(1).All(x => x > 0))
42+
columnType = new VectorType(mlType, shapeArray[0] > 0 ? shapeArray : shapeArray.Skip(1).ToArray());
43+
44+
// There can be at most two metadata fields.
45+
// 1. The first field is always present. Its value is this operator's type. For example,
46+
// if an output is produced by an operator of type "Placeholder", the value of this field should be "Placeholder".
47+
// 2. The second field stores operators whose outputs are consumed by this operator. In other words,
48+
// these values are names of some upstream operators which should be evaluated before executing
49+
// the current operator. It's possible that one operator doesn't need any input, so this field
50+
// can be missing.
51+
var metadataBuilder = new MetadataBuilder();
52+
metadataBuilder.Add(OpType, TextType.Instance, (ref ReadOnlyMemory<char> value) => value = op.OpType.AsMemory());
4453
if (op.NumInputs > 0)
45-
{
46-
var inputOps = new ReadOnlyMemory<char>[op.NumInputs];
47-
for (int i = 0; i < op.NumInputs; i++)
48-
{
49-
var input = op.GetInput(i);
50-
inputOps[i] = new ReadOnlyMemory<char>(input.Operation.Name.ToArray());
51-
}
52-
inputOpsGetter = (int col, ref VBuffer<ReadOnlyMemory<char>> dst) =>
53-
dst = new VBuffer<ReadOnlyMemory<char>>(op.NumInputs, inputOps);
54-
}
55-
inputOpsGetters.Add(inputOpsGetter);
56-
57-
MetadataUtils.MetadataGetter<ReadOnlyMemory<char>> opTypeGetter =
58-
(int col, ref ReadOnlyMemory<char> dst) => dst = new ReadOnlyMemory<char>(op.OpType.ToArray());
59-
opTypeGetters.Add(opTypeGetter);
60-
61-
var columnType = Utils.Size(shapeArray) == 1 && shapeArray[0] <= 0 ? new VectorType(mlType) :
62-
Utils.Size(shapeArray) > 0 && shapeArray.Skip(1).All(x => x > 0) ?
63-
new VectorType(mlType, shapeArray[0] > 0 ? shapeArray : shapeArray.Skip(1).ToArray())
64-
: new VectorType(mlType);
65-
res.Add(new KeyValuePair<string, ColumnType>(op.Name, columnType));
54+
metadataBuilder.Add(InputOps, new VectorType(TextType.Instance, op.NumInputs),
55+
(ref VBuffer<ReadOnlyMemory<char>> value) =>
56+
{
57+
var bufferEditor = VBufferEditor.Create(ref value, op.NumInputs);
58+
for (int i = 0; i < op.NumInputs; ++i)
59+
bufferEditor.Values[i] = op.GetInput(i).Operation.Name.AsMemory();
60+
value = bufferEditor.Commit();
61+
});
62+
63+
schemaBuilder.AddColumn(op.Name, columnType, metadataBuilder.GetMetadata());
6664
}
67-
return Schema.Create(new TensorFlowSchema(ectx, res.ToArray(), opTypeGetters.ToArray(), inputOpsGetters.ToArray(), inputOpsLengths.ToArray()));
65+
return schemaBuilder.GetSchema();
6866
}
6967

7068
/// <summary>
@@ -75,25 +73,27 @@ internal static Schema GetModelSchema(IExceptionContext ectx, TFGraph graph, str
7573
/// of kind <see cref="OpType"/>, indicating the operation type of the node, and if that node has inputs in the graph,
7674
/// it contains metadata of kind <see cref="InputOps"/>, indicating the names of the input nodes.
7775
/// </summary>
78-
/// <param name="env">The environment to use.</param>
79-
/// <param name="modelPath">Model to load.</param>
80-
public static Schema GetModelSchema(IHostEnvironment env, string modelPath)
76+
/// <param name="ectx">An <see cref="IExceptionContext"/>.</param>
77+
/// <param name="modelFile">The name of the file containing the TensorFlow model. Currently only frozen model
78+
/// format is supported.</param>
79+
public static Schema GetModelSchema(IExceptionContext ectx, string modelFile)
8180
{
82-
var model = LoadTensorFlowModel(env, modelPath);
83-
return GetModelSchema(env, model.Session.Graph);
81+
var bytes = File.ReadAllBytes(modelFile);
82+
var session = LoadTFSession(ectx, bytes, modelFile);
83+
return GetModelSchema(ectx, session.Graph);
8484
}
8585

8686
/// <summary>
8787
/// This is a convenience method for iterating over the nodes of a TensorFlow model graph. It
88-
/// iterates over the columns of the <see cref="ISchema"/> returned by <see cref="GetModelSchema(IHostEnvironment, string)"/>,
88+
/// iterates over the columns of the <see cref="Schema"/> returned by <see cref="GetModelSchema(IExceptionContext, string)"/>,
8989
/// and for each one it returns a tuple containing the name, operation type, column type and an array of input node names.
9090
/// This method is convenient for filtering nodes based on certain criteria, for example, by the operation type.
9191
/// </summary>
92-
/// <param name="modelPath">Model to load.</param>
92+
/// <param name="modelFile">The name of the file containing the TensorFlow model.</param>
9393
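/// <returns>One tuple per node, containing the node's name, operation type, column type and an array of input node names.</returns>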
94-
public static IEnumerable<(string, string, ColumnType, string[])> GetModelNodes(string modelPath)
94+
public static IEnumerable<(string, string, ColumnType, string[])> GetModelNodes(string modelFile)
9595
{
96-
var schema = GetModelSchema(new MLContext(), modelPath);
96+
var schema = GetModelSchema(null, modelFile);
9797

9898
for (int i = 0; i < schema.Count; i++)
9999
{
@@ -308,12 +308,6 @@ private static void CreateTempDirectoryWithAcl(string folder, string identity)
308308
}
309309
}
310310

311-
/// <summary>
312-
/// Load TensorFlow model into memory.
313-
/// </summary>
314-
/// <param name="env">The environment to use.</param>
315-
/// <param name="modelPath">The model to load.</param>
316-
/// <returns></returns>
317311
public static TensorFlowModelInfo LoadTensorFlowModel(IHostEnvironment env, string modelPath)
318312
{
319313
var session = GetSession(env, modelPath);
@@ -358,55 +352,5 @@ internal static bool IsTypeSupported(TFDataType tfoutput)
358352
return false;
359353
}
360354
}
361-
362-
private sealed class TensorFlowSchema : SimpleSchemaBase
363-
{
364-
private readonly MetadataUtils.MetadataGetter<ReadOnlyMemory<char>>[] _opTypeGetters;
365-
private readonly MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>[] _inputOpsGetters;
366-
private readonly int[] _inputOpsLengths;
367-
368-
public TensorFlowSchema(IExceptionContext ectx, KeyValuePair<string, ColumnType>[] columns,
369-
MetadataUtils.MetadataGetter<ReadOnlyMemory<char>>[] opTypeGetters,
370-
MetadataUtils.MetadataGetter<VBuffer<ReadOnlyMemory<char>>>[] inputOpsGetters, int[] inputOpsLengths)
371-
: base(ectx, columns)
372-
{
373-
ectx.CheckParam(Utils.Size(opTypeGetters) == ColumnCount, nameof(opTypeGetters));
374-
ectx.CheckParam(Utils.Size(inputOpsGetters) == ColumnCount, nameof(inputOpsGetters));
375-
ectx.CheckParam(Utils.Size(inputOpsLengths) == ColumnCount, nameof(inputOpsLengths));
376-
377-
_opTypeGetters = opTypeGetters;
378-
_inputOpsGetters = inputOpsGetters;
379-
_inputOpsLengths = inputOpsLengths;
380-
}
381-
382-
protected override void GetMetadataCore<TValue>(string kind, int col, ref TValue value)
383-
{
384-
Ectx.Assert(0 <= col && col < ColumnCount);
385-
if (kind == OpType)
386-
_opTypeGetters[col].Marshal(col, ref value);
387-
else if (kind == InputOps && _inputOpsGetters[col] != null)
388-
_inputOpsGetters[col].Marshal(col, ref value);
389-
else
390-
throw Ectx.ExceptGetMetadata();
391-
}
392-
393-
protected override ColumnType GetMetadataTypeOrNullCore(string kind, int col)
394-
{
395-
Ectx.Assert(0 <= col && col < ColumnCount);
396-
if (kind == OpType)
397-
return TextType.Instance;
398-
if (kind == InputOps && _inputOpsGetters[col] != null)
399-
return new VectorType(TextType.Instance, _inputOpsLengths[col]);
400-
return null;
401-
}
402-
403-
protected override IEnumerable<KeyValuePair<string, ColumnType>> GetMetadataTypesCore(int col)
404-
{
405-
Ectx.Assert(0 <= col && col < ColumnCount);
406-
yield return new KeyValuePair<string, ColumnType>(OpType, TextType.Instance);
407-
if (_inputOpsGetters[col] != null)
408-
yield return new KeyValuePair<string, ColumnType>(InputOps, new VectorType(TextType.Instance, _inputOpsLengths[col]));
409-
}
410-
}
411355
}
412356
}

0 commit comments

Comments
 (0)