diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln
index 211a6f7ca3..012d56e705 100644
--- a/Microsoft.ML.sln
+++ b/Microsoft.ML.sln
@@ -115,6 +115,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Analyzer", "sr
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.StaticPipelineTesting", "test\Microsoft.ML.StaticPipelineTesting\Microsoft.ML.StaticPipelineTesting.csproj", "{8B38BF24-35F4-4787-A9C5-22D35987106E}"
EndProject
+Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.ML.DnnAnalyzer", "src\Microsoft.ML.DnnAnalyzer\Microsoft.ML.DnnAnalyzer\Microsoft.ML.DnnAnalyzer.csproj", "{73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}"
+EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -419,6 +421,14 @@ Global
{8B38BF24-35F4-4787-A9C5-22D35987106E}.Release|Any CPU.Build.0 = Release|Any CPU
{8B38BF24-35F4-4787-A9C5-22D35987106E}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
{8B38BF24-35F4-4787-A9C5-22D35987106E}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
+ {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
+ {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug|Any CPU.Build.0 = Debug|Any CPU
+ {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU
+ {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU
+ {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release|Any CPU.ActiveCfg = Release|Any CPU
+ {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release|Any CPU.Build.0 = Release|Any CPU
+ {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU
+ {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -466,6 +476,7 @@ Global
{570A0B8A-5463-44D2-8521-54C0CA4CACA9} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{6DEF0F40-3853-47B3-8165-5F24BA5E14DF} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
{8B38BF24-35F4-4787-A9C5-22D35987106E} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4}
+ {73DAAC82-D308-48CC-8FFE-3B037F8BBCCA} = {09EADF06-BE25-4228-AB53-95AE3E15B530}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D}
diff --git a/src/Microsoft.ML.Data/DataView/SimpleRow.cs b/src/Microsoft.ML.Data/DataView/SimpleRow.cs
index b94f8f90ec..b0ba12b5ab 100644
--- a/src/Microsoft.ML.Data/DataView/SimpleRow.cs
+++ b/src/Microsoft.ML.Data/DataView/SimpleRow.cs
@@ -64,97 +64,135 @@ public bool IsColumnActive(int col)
/// An that takes all column names and types as constructor parameters.
/// The columns do not have metadata.
///
- public sealed class SimpleSchema : ISchema
+ public abstract class SimpleSchemaBase : ISchema
{
- private readonly IExceptionContext _ectx;
+ protected readonly IExceptionContext Ectx;
private readonly string[] _names;
- private readonly ColumnType[] _types;
- private readonly Dictionary _columnNameMap;
- private readonly MetadataUtils.MetadataGetter>>[] _keyValueGetters;
+ protected readonly ColumnType[] Types;
+ protected readonly Dictionary ColumnNameMap;
- public int ColumnCount => _types.Length;
+ public int ColumnCount => Types.Length;
- public SimpleSchema(IExceptionContext ectx, params KeyValuePair[] columns)
+ protected SimpleSchemaBase(IExceptionContext ectx, params KeyValuePair[] columns)
{
Contracts.CheckValueOrNull(ectx);
- _ectx = ectx;
- _ectx.CheckValue(columns, nameof(columns));
+ Ectx = ectx;
+ Ectx.CheckValue(columns, nameof(columns));
_names = new string[columns.Length];
- _types = new ColumnType[columns.Length];
- _columnNameMap = new Dictionary();
+ Types = new ColumnType[columns.Length];
+ ColumnNameMap = new Dictionary();
for (int i = 0; i < columns.Length; i++)
{
_names[i] = columns[i].Key;
- _types[i] = columns[i].Value;
- if (_columnNameMap.ContainsKey(columns[i].Key))
+ Types[i] = columns[i].Value;
+ if (ColumnNameMap.ContainsKey(columns[i].Key))
throw ectx.ExceptParam(nameof(columns), $"Duplicate column name: '{columns[i].Key}'");
- _columnNameMap[columns[i].Key] = i;
- }
- _keyValueGetters = new MetadataUtils.MetadataGetter>>[ColumnCount];
- }
-
- public SimpleSchema(IExceptionContext ectx, KeyValuePair[] columns, Dictionary>>> keyValues)
- : this(ectx, columns)
- {
- foreach (var kvp in keyValues)
- {
- var name = kvp.Key;
- var getter = kvp.Value;
- if (!_columnNameMap.TryGetValue(name, out int col))
- throw _ectx.ExceptParam(nameof(keyValues), $"Output schema does not contain column '{name}'");
- if (!_types[col].ItemType.IsKey)
- throw _ectx.ExceptParam(nameof(keyValues), $"Column '{name}' is not a key column, so it cannot have key value metadata");
- _keyValueGetters[col] = getter;
+ ColumnNameMap[columns[i].Key] = i;
}
}
public bool TryGetColumnIndex(string name, out int col)
{
- return _columnNameMap.TryGetValue(name, out col);
+ return ColumnNameMap.TryGetValue(name, out col);
}
public string GetColumnName(int col)
{
- _ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
+ Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
return _names[col];
}
public ColumnType GetColumnType(int col)
{
- _ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
- return _types[col];
+ Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
+ return Types[col];
}
public IEnumerable> GetMetadataTypes(int col)
{
- _ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
+ Ectx.Assert(0 <= col && col < ColumnCount);
+ return GetMetadataTypesCore(col);
+ }
+
+ protected abstract IEnumerable> GetMetadataTypesCore(int col);
+
+ public ColumnType GetMetadataTypeOrNull(string kind, int col)
+ {
+ Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
+ return GetMetadataTypeOrNullCore(kind, col);
+ }
+
+ protected abstract ColumnType GetMetadataTypeOrNullCore(string kind, int col);
+
+ public void GetMetadata(string kind, int col, ref TValue value)
+ {
+ Ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
+ GetMetadataCore(kind, col, ref value);
+ }
+
+ protected abstract void GetMetadataCore(string kind, int col, ref TValue value);
+ }
+
+ ///
+ /// An that takes all column names and types as constructor parameters.
+ /// The columns can optionally have text metadata.
+ ///
+ public sealed class SimpleSchema : SimpleSchemaBase
+ {
+ private readonly MetadataUtils.MetadataGetter>>[] _keyValueGetters;
+
+ public SimpleSchema(IExceptionContext ectx, params KeyValuePair[] columns)
+ : base(ectx, columns)
+ {
+ _keyValueGetters = new MetadataUtils.MetadataGetter>>[ColumnCount];
+ }
+
+ public SimpleSchema(IExceptionContext ectx, KeyValuePair[] columns,
+ Dictionary>>> keyValues)
+ : this(ectx, columns)
+ {
+ foreach (var kvp in keyValues)
+ {
+ var name = kvp.Key;
+ var getter = kvp.Value;
+ if (!ColumnNameMap.TryGetValue(name, out int col))
+ throw Ectx.ExceptParam(nameof(keyValues), $"Output schema does not contain column '{name}'");
+ if (!Types[col].ItemType.IsKey)
+ throw Ectx.ExceptParam(nameof(keyValues), $"Column '{name}' is not a key column, so it cannot have key value metadata");
+ _keyValueGetters[col] = getter;
+ }
+ }
+
+ protected override IEnumerable> GetMetadataTypesCore(int col)
+ {
+ Ectx.Assert(0 <= col && col < ColumnCount);
if (_keyValueGetters[col] != null)
{
- _ectx.Assert(_types[col].ItemType.IsKey);
+ Ectx.Assert(Types[col].ItemType.IsKey);
yield return new KeyValuePair(MetadataUtils.Kinds.KeyValues,
- new VectorType(TextType.Instance, _types[col].ItemType.KeyCount));
+ new VectorType(TextType.Instance, Types[col].ItemType.KeyCount));
}
}
- public ColumnType GetMetadataTypeOrNull(string kind, int col)
+ protected override ColumnType GetMetadataTypeOrNullCore(string kind, int col)
{
- _ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
+ Ectx.Assert(0 <= col && col < ColumnCount);
if (kind == MetadataUtils.Kinds.KeyValues && _keyValueGetters[col] != null)
{
- _ectx.Assert(_types[col].ItemType.IsKey);
- return new VectorType(TextType.Instance, _types[col].ItemType.KeyCount);
+ Ectx.Assert(Types[col].ItemType.IsKey);
+ return new VectorType(TextType.Instance, Types[col].ItemType.KeyCount);
}
return null;
}
- public void GetMetadata(string kind, int col, ref TValue value)
+ protected override void GetMetadataCore(string kind, int col, ref TValue value)
{
- _ectx.CheckParam(0 <= col && col < ColumnCount, nameof(col));
+ Ectx.Assert(0 <= col && col < ColumnCount);
if (kind == MetadataUtils.Kinds.KeyValues && _keyValueGetters[col] != null)
_keyValueGetters[col].Marshal(col, ref value);
else
- throw _ectx.ExceptGetMetadata();
+ throw Ectx.ExceptGetMetadata();
}
}
}
\ No newline at end of file
diff --git a/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs
new file mode 100644
index 0000000000..48fd32fc31
--- /dev/null
+++ b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/DnnAnalyzer.cs
@@ -0,0 +1,31 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Transforms.TensorFlow;
+using System;
+using System.Linq;
+
+namespace Microsoft.ML.DnnAnalyzer
+{
+ public static class DnnAnalyzer
+ {
+ public static void Main(string[] args)
+ {
+ if (Utils.Size(args) != 1)
+ {
+ Console.Error.WriteLine("Usage: dotnet DnnAnalyzer.dll ");
+ return;
+ }
+
+ foreach (var (name, opType, type, inputs) in TensorFlowUtils.GetModelNodes(args[0]))
+ {
+ var inputsString = inputs.Length == 0 ? "" : $", input nodes: {string.Join(", ", inputs)}";
+ Console.WriteLine($"Graph node: '{name}', operation type: '{opType}', output type: '{type}'{inputsString}");
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj
new file mode 100644
index 0000000000..7c77ff2ffa
--- /dev/null
+++ b/src/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer/Microsoft.ML.DnnAnalyzer.csproj
@@ -0,0 +1,19 @@
+
+
+
+ Exe
+ netcoreapp2.1
+ DnnAnalyzer
+ Microsoft.ML.TensorFlow
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs
index 4fd4258794..e63e4f56c2 100644
--- a/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorFlow/Tensorflow.cs
@@ -23,6 +23,7 @@
using size_t = System.UIntPtr;
using System.Collections.Generic;
+using System.Collections;
#pragma warning disable MSML_GeneralName
#pragma warning disable MSML_PrivateFieldName
@@ -492,7 +493,7 @@ public void SetConfig(IntPtr protoData, int length, TFStatus status = null)
/// "hot", and add a "sub" operation there the result will be "demo/hot/sub".
///
///
- internal partial class TFGraph : TFDisposableThreadSafe
+ internal partial class TFGraph : TFDisposableThreadSafe, IEnumerable
{
// extern TF_Graph * TF_NewGraph ();
[DllImport(NativeBinding.TensorFlowLibrary)]
@@ -696,6 +697,33 @@ public override string ToString()
IntPtr len;
return TF_GraphDebugString(Handle, out len);
}
+
+ [DllImport(NativeBinding.TensorFlowLibrary)]
+ private static unsafe extern TF_Operation TF_GraphNextOperation(TF_Graph graph, ref IntPtr pos);
+
+ ///
+ /// Returns the enumerator that returns all the TFOperations in a graph.
+ ///
+ /// The enumerator.
+ private IEnumerable GetEnumerable()
+ {
+ if (handle == IntPtr.Zero)
+ ObjectDisposedException();
+ IntPtr token = IntPtr.Zero;
+ IntPtr operll;
+ while ((operll = TF_GraphNextOperation(handle, ref token)) != IntPtr.Zero)
+ yield return new TFOperation(this, operll);
+ }
+
+ public IEnumerator GetEnumerator()
+ {
+ return GetEnumerable().GetEnumerator();
+ }
+
+ IEnumerator IEnumerable.GetEnumerator()
+ {
+ return GetEnumerator();
+ }
}
///
@@ -736,6 +764,48 @@ public TFOutput this[int idx]
return new TFOutput(this, idx);
}
}
+
+ // extern TF_Output TF_OperationInput (TF_Input oper_in);
+ [DllImport(NativeBinding.TensorFlowLibrary)]
+ private static extern TFOutput TF_OperationInput(TFInput oper_in);
+
+ public TFOutput GetInput(int idx)
+ {
+ return TF_OperationInput(new TFInput() { Operation = handle, Index = idx });
+ }
+
+ [DllImport(NativeBinding.TensorFlowLibrary)]
+ private static extern IntPtr TF_OperationName(TF_Operation oper);
+
+ ///
+ /// The name for this operation/
+ ///
+ /// The name.
+ public string Name => handle == IntPtr.Zero ? "" : TF_OperationName(handle).GetStr();
+
+ [DllImport(NativeBinding.TensorFlowLibrary)]
+ private static extern IntPtr TF_OperationOpType(TF_Operation oper);
+
+ public string OpType => handle == IntPtr.Zero ? "" : TF_OperationOpType(handle).GetStr();
+
+ [DllImport(NativeBinding.TensorFlowLibrary)]
+ private static extern int TF_OperationNumOutputs(TF_Operation oper);
+
+ ///
+ /// Gets the number of outputs on this operation.
+ ///
+ /// The number outputs.
+ public int NumOutputs => handle == IntPtr.Zero ? -1 : TF_OperationNumOutputs(handle);
+
+ [DllImport(NativeBinding.TensorFlowLibrary)]
+ private static extern int TF_OperationNumInputs(TF_Operation oper);
+
+ ///
+ /// Gets the number of inputs for this operation.
+ /// Import a serialized graph into this graph, using the specified importing options.
+ ///
+ /// The number inputs.
+ public int NumInputs => TF_OperationNumInputs(handle);
}
///
@@ -1768,15 +1838,6 @@ internal struct TFInput
///
public int Index;
- // extern TF_Output TF_OperationInput (TF_Input oper_in);
- [DllImport(NativeBinding.TensorFlowLibrary)]
- private static extern TFOutput TF_OperationInput(TFInput oper_in);
-
- public TFOutput GetOutput(TFInput operIn)
- {
- return TF_OperationInput(operIn);
- }
-
// extern TF_DataType TF_OperationInputType (TF_Input oper_in);
[DllImport(NativeBinding.TensorFlowLibrary)]
private static extern TFDataType TF_OperationInputType(TFInput oper_in);
diff --git a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs
index e77309c8b0..54030aec91 100644
--- a/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorFlow/TensorflowUtils.cs
@@ -4,16 +4,21 @@
using System;
using System.Collections.Generic;
+using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.ImageAnalytics.EntryPoints;
+using Microsoft.ML.Runtime.Internal.Utilities;
namespace Microsoft.ML.Transforms.TensorFlow
{
public static class TensorFlowUtils
{
+ public const string OpType = "OpType";
+ public const string InputOps = "InputOps";
+
// This method is needed for the Pipeline API, since ModuleCatalog does not load entry points that are located
// in assemblies that aren't directly used in the code. Users who want to use TensorFlow components will have to call
// TensorFlowUtils.Initialize() before creating the pipeline.
@@ -25,7 +30,95 @@ public static void Initialize()
ImageAnalytics.Initialize();
}
+ private static ISchema GetModelSchema(IExceptionContext ectx, TFGraph graph)
+ {
+ var res = new List>();
+ var opTypeGetters = new List>>();
+ var inputOpsGetters = new List>>>();
+ var inputOpsLengths = new List();
+ foreach (var op in graph)
+ {
+ var tfType = op[0].OutputType;
+ var mlType = Tf2MlNetTypeOrNull(tfType);
+
+ // If the type is not supported in ML.NET then we cannot represent it as a column in an ISchema.
+ // We also cannot output it with a TensorFlowTransform, so we skip it.
+ if (mlType == null)
+ continue;
+
+ var shape = graph.GetTensorShape(op[0]);
+ var shapeArray = shape.ToIntArray();
+
+ inputOpsLengths.Add(op.NumInputs);
+ MetadataUtils.MetadataGetter>> inputOpsGetter = null;
+ if (op.NumInputs > 0)
+ {
+ var inputOps = new ReadOnlyMemory[op.NumInputs];
+ for (int i = 0; i < op.NumInputs; i++)
+ {
+ var input = op.GetInput(i);
+ inputOps[i] = new ReadOnlyMemory(input.Operation.Name.ToArray());
+ }
+ inputOpsGetter = (int col, ref VBuffer> dst) =>
+ dst = new VBuffer>(op.NumInputs, inputOps);
+ }
+ inputOpsGetters.Add(inputOpsGetter);
+
+ var opType = op.OpType;
+ MetadataUtils.MetadataGetter> opTypeGetter =
+ (int col, ref ReadOnlyMemory dst) => dst = new ReadOnlyMemory(opType.ToArray());
+ opTypeGetters.Add(opTypeGetter);
+
+ var columnType = Utils.Size(shapeArray) == 1 && shapeArray[0] == -1 ? new VectorType(mlType) :
+ Utils.Size(shapeArray) > 0 && shapeArray.Skip(1).All(x => x > 0) ?
+ new VectorType(mlType, shapeArray[0] > 0 ? shapeArray : shapeArray.Skip(1).ToArray())
+ : new VectorType(mlType);
+ res.Add(new KeyValuePair(op.Name, columnType));
+ }
+ return new TensorFlowSchema(ectx, res.ToArray(), opTypeGetters.ToArray(), inputOpsGetters.ToArray(), inputOpsLengths.ToArray());
+ }
+
+ public static ISchema GetModelSchema(IExceptionContext ectx, string modelFile)
+ {
+ var bytes = File.ReadAllBytes(modelFile);
+ var session = LoadTFSession(ectx, bytes, modelFile);
+ return GetModelSchema(ectx, session.Graph);
+ }
+
+ public static IEnumerable<(string, string, ColumnType, string[])> GetModelNodes(string modelFile)
+ {
+ var schema = GetModelSchema(null, modelFile);
+
+ for (int i = 0; i < schema.ColumnCount; i++)
+ {
+ var name = schema.GetColumnName(i);
+ var type = schema.GetColumnType(i);
+
+ var metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, i);
+ Contracts.Assert(metadataType != null && metadataType.IsText);
+ ReadOnlyMemory opType = default;
+ schema.GetMetadata(TensorFlowUtils.OpType, i, ref opType);
+ metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, i);
+ VBuffer> inputOps = default;
+ if (metadataType != null)
+ {
+ Contracts.Assert(metadataType.IsKnownSizeVector && metadataType.ItemType.IsText);
+ schema.GetMetadata(TensorFlowUtils.InputOps, i, ref inputOps);
+ }
+ yield return (name, opType.ToString(), type,
+ Utils.Size(inputOps.Values) > 0 ? inputOps.Values.Select(input => input.ToString()).ToArray() : new string[0]);
+ }
+ }
+
internal static PrimitiveType Tf2MlNetType(TFDataType type)
+ {
+ var mlNetType = Tf2MlNetTypeOrNull(type);
+ if (mlNetType == null)
+ throw new NotSupportedException("TensorFlow type not supported.");
+ return mlNetType;
+ }
+
+ private static PrimitiveType Tf2MlNetTypeOrNull(TFDataType type)
{
switch (type)
{
@@ -42,10 +135,29 @@ internal static PrimitiveType Tf2MlNetType(TFDataType type)
case TFDataType.UInt64:
return NumberType.U8;
default:
- throw new NotSupportedException("TensorFlow type not supported.");
+ return null;
}
}
+ internal static TFSession LoadTFSession(IExceptionContext ectx, byte[] modelBytes, string modelFile = null)
+ {
+ var graph = new TFGraph();
+ try
+ {
+ graph.Import(modelBytes, "");
+ }
+ catch (Exception ex)
+ {
+ if (!string.IsNullOrEmpty(modelFile))
+ throw ectx.Except($"TensorFlow exception triggered while loading model from '{modelFile}'");
+#pragma warning disable MSML_NoMessagesForLoadContext
+ throw ectx.ExceptDecode(ex, "Tensorflow exception triggered while loading model.");
+#pragma warning restore MSML_NoMessagesForLoadContext
+
+ }
+ return new TFSession(graph);
+ }
+
internal static unsafe void FetchData(IntPtr data, T[] result)
{
var size = result.Length;
@@ -73,5 +185,55 @@ internal static bool IsTypeSupported(TFDataType tfoutput)
return false;
}
}
+
+ private sealed class TensorFlowSchema : SimpleSchemaBase
+ {
+ private readonly MetadataUtils.MetadataGetter>[] _opTypeGetters;
+ private readonly MetadataUtils.MetadataGetter>>[] _inputOpsGetters;
+ private readonly int[] _inputOpsLengths;
+
+ public TensorFlowSchema(IExceptionContext ectx, KeyValuePair[] columns,
+ MetadataUtils.MetadataGetter>[] opTypeGetters,
+ MetadataUtils.MetadataGetter>>[] inputOpsGetters, int[] inputOpsLengths)
+ : base(ectx, columns)
+ {
+ ectx.CheckParam(Utils.Size(opTypeGetters) == ColumnCount, nameof(opTypeGetters));
+ ectx.CheckParam(Utils.Size(inputOpsGetters) == ColumnCount, nameof(inputOpsGetters));
+ ectx.CheckParam(Utils.Size(inputOpsLengths) == ColumnCount, nameof(inputOpsLengths));
+
+ _opTypeGetters = opTypeGetters;
+ _inputOpsGetters = inputOpsGetters;
+ _inputOpsLengths = inputOpsLengths;
+ }
+
+ protected override void GetMetadataCore(string kind, int col, ref TValue value)
+ {
+ Ectx.Assert(0 <= col && col < ColumnCount);
+ if (kind == OpType)
+ _opTypeGetters[col].Marshal(col, ref value);
+ else if (kind == InputOps && _inputOpsGetters[col] != null)
+ _inputOpsGetters[col].Marshal(col, ref value);
+ else
+ throw Ectx.ExceptGetMetadata();
+ }
+
+ protected override ColumnType GetMetadataTypeOrNullCore(string kind, int col)
+ {
+ Ectx.Assert(0 <= col && col < ColumnCount);
+ if (kind == OpType)
+ return TextType.Instance;
+ if (kind == InputOps && _inputOpsGetters[col] != null)
+ return new VectorType(TextType.Instance, _inputOpsLengths[col]);
+ return null;
+ }
+
+ protected override IEnumerable> GetMetadataTypesCore(int col)
+ {
+ Ectx.Assert(0 <= col && col < ColumnCount);
+ yield return new KeyValuePair(OpType, TextType.Instance);
+ if (_inputOpsGetters[col] != null)
+ yield return new KeyValuePair(InputOps, new VectorType(TextType.Instance, _inputOpsLengths[col]));
+ }
+ }
}
}
diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
index 578a9d5778..69532de8fc 100644
--- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
+++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs
@@ -122,6 +122,7 @@ private static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext
byte[] modelBytes = null;
if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray()))
throw env.ExceptDecode();
+ var session = TensorFlowUtils.LoadTFSession(env, modelBytes);
var numInputs = ctx.Reader.ReadInt32();
env.CheckDecode(numInputs > 0);
string[] inputs = new string[numInputs];
@@ -138,7 +139,7 @@ private static TensorFlowTransform Create(IHostEnvironment env, ModelLoadContext
for (int j = 0; j < outputs.Length; j++)
outputs[j] = ctx.LoadNonEmptyString();
- return new TensorFlowTransform(env, modelBytes, inputs, outputs);
+ return new TensorFlowTransform(env, session, inputs, outputs);
}
// Factory method for SignatureDataTransform.
@@ -160,27 +161,12 @@ private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx,
private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema)
=> Create(env, ctx).MakeRowMapper(inputSchema);
- private TFSession LoadTFSession(byte[] modelBytes)
- {
- var graph = new TFGraph();
- try
- {
- graph.Import(modelBytes, "");
- }
- catch (Exception ex)
- {
-#pragma warning disable MSML_NoMessagesForLoadContext
- throw _host.ExceptDecode(ex, "Tensorflow exception triggered while loading model.");
-#pragma warning restore MSML_NoMessagesForLoadContext
- }
- return new TFSession(graph);
- }
-
- private static byte[] CheckFileAndRead(IHostEnvironment env, string modelFile)
+ private static TFSession CheckFileAndRead(IHostEnvironment env, string modelFile)
{
env.CheckNonWhiteSpace(modelFile, nameof(modelFile));
env.CheckUserArg(File.Exists(modelFile), nameof(modelFile));
- return File.ReadAllBytes(modelFile);
+ var bytes = File.ReadAllBytes(modelFile);
+ return TensorFlowUtils.LoadTFSession(env, bytes, modelFile);
}
public TensorFlowTransform(IHostEnvironment env, string modelFile, string[] inputs, string[] outputs) :
@@ -188,15 +174,14 @@ public TensorFlowTransform(IHostEnvironment env, string modelFile, string[] inpu
{
}
- private TensorFlowTransform(IHostEnvironment env, byte[] modelBytes, string[] inputs, string[] outputs)
+ private TensorFlowTransform(IHostEnvironment env, TFSession session, string[] inputs, string[] outputs)
{
Contracts.CheckValue(env, nameof(env));
_host = env.Register(nameof(RegistrationName));
- _host.CheckValue(modelBytes, nameof(modelBytes));
+ _host.CheckValue(session, nameof(session));
_host.CheckNonEmpty(inputs, nameof(inputs));
_host.CheckNonEmpty(outputs, nameof(outputs));
-
- Session = LoadTFSession(modelBytes);
+ Session = session;
foreach (var input in inputs)
{
_host.CheckNonWhiteSpace(input, nameof(inputs));
@@ -204,7 +189,7 @@ private TensorFlowTransform(IHostEnvironment env, byte[] modelBytes, string[] in
throw _host.ExceptParam(nameof(inputs), $"Input column '{input}' does not exist in the model");
var tfInput = new TFOutput(Session.Graph[input]);
if (!TensorFlowUtils.IsTypeSupported(tfInput.OutputType))
- throw _host.ExceptParam(nameof(modelBytes), $"Input type '{tfInput.OutputType}' of input column '{input}' is not supported in TensorFlow");
+ throw _host.ExceptParam(nameof(session), $"Input type '{tfInput.OutputType}' of input column '{input}' is not supported in TensorFlow");
}
var newNames = new HashSet();
diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
index da9102fbb2..71f5f95f33 100644
--- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
+++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs
@@ -10,6 +10,7 @@
using Microsoft.ML.Runtime.LightGBM;
using Microsoft.ML.Transforms;
using Microsoft.ML.Transforms.TensorFlow;
+using System;
using System.Collections.Generic;
using System.IO;
using Xunit;
@@ -181,6 +182,95 @@ public void TensorFlowTransformInceptionTest()
}
}
+ [Fact]
+ public void TensorFlowInputsOutputsSchemaTest()
+ {
+ using (var env = new ConsoleEnvironment(seed: 1, conc: 1))
+ {
+ var model_location = "mnist_model/frozen_saved_model.pb";
+ var schema = TensorFlowUtils.GetModelSchema(env, model_location);
+ Assert.Equal(54, schema.ColumnCount);
+ Assert.True(schema.TryGetColumnIndex("Placeholder", out int col));
+ var type = schema.GetColumnType(col).AsVector;
+ Assert.Equal(2, type.DimCount);
+ Assert.Equal(28, type.GetDim(0));
+ Assert.Equal(28, type.GetDim(1));
+ var metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col);
+ Assert.NotNull(metadataType);
+ Assert.True(metadataType.IsText);
+ ReadOnlyMemory opType = default;
+ schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType);
+ Assert.Equal("Placeholder", opType.ToString());
+ metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col);
+ Assert.Null(metadataType);
+
+ Assert.True(schema.TryGetColumnIndex("conv2d/Conv2D/ReadVariableOp", out col));
+ type = schema.GetColumnType(col).AsVector;
+ Assert.Equal(4, type.DimCount);
+ Assert.Equal(5, type.GetDim(0));
+ Assert.Equal(5, type.GetDim(1));
+ Assert.Equal(1, type.GetDim(2));
+ Assert.Equal(32, type.GetDim(3));
+ metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col);
+ Assert.NotNull(metadataType);
+ Assert.True(metadataType.IsText);
+ schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType);
+ Assert.Equal("Identity", opType.ToString());
+ metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col);
+ Assert.NotNull(metadataType);
+ VBuffer> inputOps = default;
+ schema.GetMetadata(TensorFlowUtils.InputOps, col, ref inputOps);
+ Assert.Equal(1, inputOps.Length);
+ Assert.Equal("conv2d/kernel", inputOps.Values[0].ToString());
+
+ Assert.True(schema.TryGetColumnIndex("conv2d/Conv2D", out col));
+ type = schema.GetColumnType(col).AsVector;
+ Assert.Equal(3, type.DimCount);
+ Assert.Equal(28, type.GetDim(0));
+ Assert.Equal(28, type.GetDim(1));
+ Assert.Equal(32, type.GetDim(2));
+ metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col);
+ Assert.NotNull(metadataType);
+ Assert.True(metadataType.IsText);
+ schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType);
+ Assert.Equal("Conv2D", opType.ToString());
+ metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col);
+ Assert.NotNull(metadataType);
+ schema.GetMetadata(TensorFlowUtils.InputOps, col, ref inputOps);
+ Assert.Equal(2, inputOps.Length);
+ Assert.Equal("reshape/Reshape", inputOps.Values[0].ToString());
+ Assert.Equal("conv2d/Conv2D/ReadVariableOp", inputOps.Values[1].ToString());
+
+ Assert.True(schema.TryGetColumnIndex("Softmax", out col));
+ type = schema.GetColumnType(col).AsVector;
+ Assert.Equal(1, type.DimCount);
+ Assert.Equal(10, type.GetDim(0));
+ metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.OpType, col);
+ Assert.NotNull(metadataType);
+ Assert.True(metadataType.IsText);
+ schema.GetMetadata(TensorFlowUtils.OpType, col, ref opType);
+ Assert.Equal("Softmax", opType.ToString());
+ metadataType = schema.GetMetadataTypeOrNull(TensorFlowUtils.InputOps, col);
+ Assert.NotNull(metadataType);
+ schema.GetMetadata(TensorFlowUtils.InputOps, col, ref inputOps);
+ Assert.Equal(1, inputOps.Length);
+ Assert.Equal("sequential/dense_1/BiasAdd", inputOps.Values[0].ToString());
+
+ model_location = "model_matmul/frozen_saved_model.pb";
+ schema = TensorFlowUtils.GetModelSchema(env, model_location);
+ char name = 'a';
+ for (int i = 0; i < schema.ColumnCount; i++)
+ {
+ Assert.Equal(name.ToString(), schema.GetColumnName(i));
+ type = schema.GetColumnType(i).AsVector;
+ Assert.Equal(2, type.DimCount);
+ Assert.Equal(2, type.GetDim(0));
+ Assert.Equal(2, type.GetDim(1));
+ name++;
+ }
+ }
+ }
+
[Fact]
public void TensorFlowTransformMNISTConvTest()
{