diff --git a/Microsoft.ML.sln b/Microsoft.ML.sln index db998668dd..1421025ec9 100644 --- a/Microsoft.ML.sln +++ b/Microsoft.ML.sln @@ -93,7 +93,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Microsoft.ML.CpuMath", "Mic EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "tools-local", "tools-local", "{7F13E156-3EBA-4021-84A5-CD56BA72F99E}" EndProject -Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeAnalyzer", "tools-local\Microsoft.ML.CodeAnalyzer\Microsoft.ML.CodeAnalyzer.csproj", "{B4E55B2D-2A92-46E7-B72F-E76D6FD83440}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.InternalCodeAnalyzer", "tools-local\Microsoft.ML.InternalCodeAnalyzer\Microsoft.ML.InternalCodeAnalyzer.csproj", "{B4E55B2D-2A92-46E7-B72F-E76D6FD83440}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.CodeAnalyzer.Tests", "test\Microsoft.ML.CodeAnalyzer.Tests\Microsoft.ML.CodeAnalyzer.Tests.csproj", "{3E4ABF07-7970-4BE6-B45B-A13D3C397545}" EndProject @@ -111,6 +111,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.HalLearners", EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.TensorFlow", "src\Microsoft.ML.TensorFlow\Microsoft.ML.TensorFlow.csproj", "{570A0B8A-5463-44D2-8521-54C0CA4CACA9}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.Analyzer", "src\Microsoft.ML.Analyzer\Microsoft.ML.Analyzer.csproj", "{6DEF0F40-3853-47B3-8165-5F24BA5E14DF}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.ML.StaticPipelineTesting", "test\Microsoft.ML.StaticPipelineTesting\Microsoft.ML.StaticPipelineTesting.csproj", "{8B38BF24-35F4-4787-A9C5-22D35987106E}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -399,6 +403,22 @@ Global {570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Release|Any CPU.Build.0 = Release|Any CPU 
{570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU {570A0B8A-5463-44D2-8521-54C0CA4CACA9}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU + {6DEF0F40-3853-47B3-8165-5F24BA5E14DF}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {6DEF0F40-3853-47B3-8165-5F24BA5E14DF}.Debug|Any CPU.Build.0 = Debug|Any CPU + {6DEF0F40-3853-47B3-8165-5F24BA5E14DF}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU + {6DEF0F40-3853-47B3-8165-5F24BA5E14DF}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU + {6DEF0F40-3853-47B3-8165-5F24BA5E14DF}.Release|Any CPU.ActiveCfg = Release|Any CPU + {6DEF0F40-3853-47B3-8165-5F24BA5E14DF}.Release|Any CPU.Build.0 = Release|Any CPU + {6DEF0F40-3853-47B3-8165-5F24BA5E14DF}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU + {6DEF0F40-3853-47B3-8165-5F24BA5E14DF}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU + {8B38BF24-35F4-4787-A9C5-22D35987106E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {8B38BF24-35F4-4787-A9C5-22D35987106E}.Debug|Any CPU.Build.0 = Debug|Any CPU + {8B38BF24-35F4-4787-A9C5-22D35987106E}.Debug-Intrinsics|Any CPU.ActiveCfg = Debug|Any CPU + {8B38BF24-35F4-4787-A9C5-22D35987106E}.Debug-Intrinsics|Any CPU.Build.0 = Debug|Any CPU + {8B38BF24-35F4-4787-A9C5-22D35987106E}.Release|Any CPU.ActiveCfg = Release|Any CPU + {8B38BF24-35F4-4787-A9C5-22D35987106E}.Release|Any CPU.Build.0 = Release|Any CPU + {8B38BF24-35F4-4787-A9C5-22D35987106E}.Release-Intrinsics|Any CPU.ActiveCfg = Release|Any CPU + {8B38BF24-35F4-4787-A9C5-22D35987106E}.Release-Intrinsics|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -444,6 +464,8 @@ Global {00E38F77-1E61-4CDF-8F97-1417D4E85053} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {A7222F41-1CF0-47D9-B80C-B4D77B027A61} = {09EADF06-BE25-4228-AB53-95AE3E15B530} {570A0B8A-5463-44D2-8521-54C0CA4CACA9} = {09EADF06-BE25-4228-AB53-95AE3E15B530} + {6DEF0F40-3853-47B3-8165-5F24BA5E14DF} = 
{09EADF06-BE25-4228-AB53-95AE3E15B530} + {8B38BF24-35F4-4787-A9C5-22D35987106E} = {AED9C836-31E3-4F3F-8ABC-929555D3F3C4} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {41165AF1-35BB-4832-A189-73060F82B01D} diff --git a/build/Dependencies.props b/build/Dependencies.props index e880e8c66b..07f59e50e9 100644 --- a/build/Dependencies.props +++ b/build/Dependencies.props @@ -12,5 +12,9 @@ 4.5.0 0.11.1 1.10.0 + + 2.9.0 + 4.5.0 + 1.2.0 diff --git a/src/Directory.Build.props b/src/Directory.Build.props index ee32523d8e..6c6117e7dd 100644 --- a/src/Directory.Build.props +++ b/src/Directory.Build.props @@ -26,7 +26,7 @@ + Include="$(MSBuildThisFileDirectory)\..\tools-local\Microsoft.ML.InternalCodeAnalyzer\Microsoft.ML.InternalCodeAnalyzer.csproj"> false Analyzer diff --git a/src/Microsoft.ML.Analyzer/Microsoft.ML.Analyzer.csproj b/src/Microsoft.ML.Analyzer/Microsoft.ML.Analyzer.csproj new file mode 100644 index 0000000000..afe45e57f5 --- /dev/null +++ b/src/Microsoft.ML.Analyzer/Microsoft.ML.Analyzer.csproj @@ -0,0 +1,13 @@ + + + + netstandard1.3 + + + + + + + + + diff --git a/src/Microsoft.ML.Analyzer/TypeIsSchemaShapeAnalyzer.cs b/src/Microsoft.ML.Analyzer/TypeIsSchemaShapeAnalyzer.cs new file mode 100644 index 0000000000..f5b9b240ee --- /dev/null +++ b/src/Microsoft.ML.Analyzer/TypeIsSchemaShapeAnalyzer.cs @@ -0,0 +1,180 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using System.Collections.Immutable; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Diagnostics; + +namespace Microsoft.ML.Analyzer +{ + [DiagnosticAnalyzer(LanguageNames.CSharp)] + public sealed class TypeIsSchemaShapeAnalyzer : DiagnosticAnalyzer + { + internal static class ShapeDiagnostic + { + private const string Category = "Type Check"; + public const string Id = "MSML_TypeShouldBeSchemaShape"; + private const string Title = "The type is not a schema shape"; + private const string Format = "Type{0} is neither a PipelineColumn nor a ValueTuple."; + internal const string Description = + "Within statically typed pipeline elements of ML.NET, the shape of the schema is determined by a type. " + + "A valid type is either an instance of one of the PipelineColumn subclasses (e.g., Scalar " + + "or something like that), or a ValueTuple containing only valid types. (So, ValueTuples containing " + + "other value tuples are fine, so long as they terminate in a PipelineColumn subclass.)"; + + internal static DiagnosticDescriptor Rule = + new DiagnosticDescriptor(Id, Title, Format, Category, + DiagnosticSeverity.Error, isEnabledByDefault: true, description: Description); + } + + internal static class ShapeParameterDiagnostic + { + private const string Category = "Type Check"; + public const string Id = "MSML_TypeParameterShouldBeSchemaShape"; + private const string Title = "The type is not a schema shape"; + private const string Format = "Type parameter {0} is not marked with [IsShape] or appropriate type constraints."; + internal const string Description = ShapeDiagnostic.Description + " " + + "If using type parameters when interacting with the statically typed pipelines, the type parameter ought to be " + + "constrained in such a way that it, either by applying the [IsShape] attribute or by having type constraints to " + + "indicate that it is valid, 
e.g., constraining the type to descend from PipelineColumn."; + + internal static DiagnosticDescriptor Rule = + new DiagnosticDescriptor(Id, Title, Format, Category, + DiagnosticSeverity.Error, isEnabledByDefault: true, description: Description); + } + + private const string AttributeName = "Microsoft.ML.Data.StaticPipe.IsShapeAttribute"; + private const string LeafTypeName = "Microsoft.ML.Data.StaticPipe.Runtime.PipelineColumn"; + + public override ImmutableArray SupportedDiagnostics => + ImmutableArray.Create(ShapeDiagnostic.Rule, ShapeParameterDiagnostic.Rule); + + public override void Initialize(AnalysisContext context) + { + context.RegisterSemanticModelAction(Analyze); + } + + private void Analyze(SemanticModelAnalysisContext context) + { + // We start with the model, then do the the method invocations. + // We could have phrased it as RegisterSyntaxNodeAction(Analyze, SyntaxKind.InvocationExpression), + // but this seemed more inefficient since getting the model and fetching the type symbols every + // single time seems to incur significant cost. The following invocation is somewhat more awkward + // since we must iterate over the invocation syntaxes ourselves, but this seems to be worthwhile. + var model = context.SemanticModel; + var comp = model.Compilation; + + // Get the symbols of the key types we are analyzing. If we can't find any of them there is + // no point in going further. + var attrType = comp.GetTypeByMetadataName(AttributeName); + if (attrType == null) + return; + var leafType = comp.GetTypeByMetadataName(LeafTypeName); + if (leafType == null) + return; + + // This internal helper method recursively determines whether an attributed type parameter + // has a valid type. It is called externally from the loop over invocations. 
+ bool CheckType(ITypeSymbol type, out string path, out ITypeSymbol problematicType) + { + if (type.TypeKind == TypeKind.TypeParameter) + { + var typeParam = (ITypeParameterSymbol)type; + path = null; + problematicType = null; + // Does the type parameter have the attribute that triggers a check? + if (type.GetAttributes().Any(attr => attr.AttributeClass == attrType)) + return true; + // Are any of the declared constraint types OK? + if (typeParam.ConstraintTypes.Any(ct => CheckType(ct, out string ctPath, out var ctProb))) + return true; + // Well, probably not good then. Let's call it a day. + problematicType = typeParam; + return false; + } + else if (type.IsTupleType) + { + INamedTypeSymbol nameType = (INamedTypeSymbol)type; + var tupleElems = nameType.TupleElements; + + for (int i = 0; i < tupleElems.Length; ++i) + { + var e = tupleElems[i]; + if (!CheckType(e.Type, out string innerPath, out problematicType)) + { + path = e.Name ?? $"Item{i + 1}"; + if (innerPath != null) + path += "." + innerPath; + return false; + } + } + path = null; + problematicType = null; + return true; + } + else + { + for (var rt = type; rt != null; rt = rt.BaseType) + { + if (rt == leafType) + { + path = null; + problematicType = null; + return true; + } + } + path = null; + problematicType = type; + return false; + } + } + + foreach (var invocation in model.SyntaxTree.GetRoot().DescendantNodes().OfType()) + { + var symbolInfo = model.GetSymbolInfo(invocation); + if (!(symbolInfo.Symbol is IMethodSymbol methodSymbol)) + { + // Should we perhaps skip when there is a method resolution failure? This is often but not always a sign of another problem. + if (symbolInfo.CandidateReason != CandidateReason.OverloadResolutionFailure || symbolInfo.CandidateSymbols.Length == 0) + continue; + methodSymbol = symbolInfo.CandidateSymbols[0] as IMethodSymbol; + if (methodSymbol == null) + continue; + } + // Analysis only applies to generic methods. 
+ if (!methodSymbol.IsGenericMethod) + continue; + // Scan the type parameters for one that has our target attribute. + for (int i = 0; i < methodSymbol.TypeParameters.Length; ++i) + { + var par = methodSymbol.TypeParameters[i]; + var attr = par.GetAttributes(); + if (attr.Length == 0) + continue; + if (!attr.Any(a => a.AttributeClass == attrType)) + continue; + // We've found it. Check the type argument to ensure it is of the appropriate type. + var p = methodSymbol.TypeArguments[i]; + if (CheckType(p, out string path, out ITypeSymbol problematicType)) + continue; + + if (problematicType.Kind == SymbolKind.TypeParameter) + { + var diagnostic = Diagnostic.Create(ShapeParameterDiagnostic.Rule, invocation.GetLocation(), problematicType.Name); + context.ReportDiagnostic(diagnostic); + } + else + { + path = path == null ? "" : " of item " + path; + var diagnostic = Diagnostic.Create(ShapeDiagnostic.Rule, invocation.GetLocation(), path); + context.ReportDiagnostic(diagnostic); + } + } + } + } + } +} diff --git a/src/Microsoft.ML.Core/Data/DataKind.cs b/src/Microsoft.ML.Core/Data/DataKind.cs index 32325f44a1..0249745691 100644 --- a/src/Microsoft.ML.Core/Data/DataKind.cs +++ b/src/Microsoft.ML.Core/Data/DataKind.cs @@ -30,7 +30,7 @@ public enum DataKind : byte Num = R4, TX = 11, -#pragma warning disable MSML_GeneralName // The data kind enum has its own logic, independnet of C# naming conventions. +#pragma warning disable MSML_GeneralName // The data kind enum has its own logic, independent of C# naming conventions. 
TXT = TX, Text = TX, diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs index 509a67bb4f..54b0c64cd9 100644 --- a/src/Microsoft.ML.Core/Data/IEstimator.cs +++ b/src/Microsoft.ML.Core/Data/IEstimator.cs @@ -192,11 +192,10 @@ public interface IDataReader public interface IDataReaderEstimator where TReader : IDataReader { + // REVIEW: you could consider the transformer to take a different , but we don't have such components + // yet, so why complicate matters? /// /// Train and return a data reader. - /// - /// REVIEW: you could consider the transformer to take a different , but we don't have such components - /// yet, so why complicate matters? /// TReader Fit(TSource input); diff --git a/src/Microsoft.ML.Core/Utilities/Utils.cs b/src/Microsoft.ML.Core/Utilities/Utils.cs index 96c23a0fe3..ac100b6dc9 100644 --- a/src/Microsoft.ML.Core/Utilities/Utils.cs +++ b/src/Microsoft.ML.Core/Utilities/Utils.cs @@ -1049,6 +1049,24 @@ public static void MarshalActionInvoke(Action act, Type genArg, TA meth.Invoke(act.Target, new object[] { arg1 }); } + /// + /// A two-argument version of . + /// + public static void MarshalActionInvoke(Action act, Type genArg, TArg1 arg1, TArg2 arg2) + { + var meth = MarshalActionInvokeCheckAndCreate(genArg, act); + meth.Invoke(act.Target, new object[] { arg1, arg2 }); + } + + /// + /// A three-argument version of . 
+ /// + public static void MarshalActionInvoke(Action act, Type genArg, TArg1 arg1, TArg2 arg2, TArg3 arg3) + { + var meth = MarshalActionInvokeCheckAndCreate(genArg, act); + meth.Invoke(act.Target, new object[] { arg1, arg2, arg3 }); + } + public static string GetDescription(this Enum value) { Type type = value.GetType(); diff --git a/src/Microsoft.ML.Data/Data/IColumn.cs b/src/Microsoft.ML.Data/Data/IColumn.cs index 2f2f496f99..04000c2d38 100644 --- a/src/Microsoft.ML.Data/Data/IColumn.cs +++ b/src/Microsoft.ML.Data/Data/IColumn.cs @@ -97,6 +97,30 @@ private static IColumn GetColumnCore(IRow row, int col) return new RowWrap(row, col); } + /// + /// Exposes a single column in a schema. The column is considered inactive. + /// + /// The schema to get the data for + /// The column to get + /// A column with false + public static IColumn GetColumn(ISchema schema, int col) + { + Contracts.CheckValue(schema, nameof(schema)); + Contracts.CheckParam(0 <= col && col < schema.ColumnCount, nameof(col)); + + Func func = GetColumnCore; + return Utils.MarshalInvoke(func, schema.GetColumnType(col).RawType, schema, col); + } + + private static IColumn GetColumnCore(ISchema schema, int col) + { + Contracts.AssertValue(schema); + Contracts.Assert(0 <= col && col < schema.ColumnCount); + Contracts.Assert(schema.GetColumnType(col).RawType == typeof(T)); + + return new SchemaWrap(schema, col); + } + /// /// Wraps the metadata of a column as a row. 
/// @@ -183,7 +207,7 @@ public static IColumn GetColumn(string name, ColumnType type, ValueGetter /// , or if null, the output row will yield default values for those implementations, /// that is, a totally static row /// A set of row columns - /// A row with items derived from + /// A row with items derived from public static IRow GetRow(ICounted counted, params IColumn[] columns) { Contracts.CheckValueOrNull(counted); @@ -229,9 +253,9 @@ private sealed class RowWrap : IColumn private readonly int _col; private MetadataRow _meta; - public string Name { get { return _row.Schema.GetColumnName(_col); } } - public ColumnType Type { get { return _row.Schema.GetColumnType(_col); } } - public bool IsActive { get { return _row.IsColumnActive(_col); } } + public string Name => _row.Schema.GetColumnName(_col); + public ColumnType Type => _row.Schema.GetColumnType(_col); + public bool IsActive => _row.IsColumnActive(_col); public IRow Metadata { @@ -254,14 +278,10 @@ public RowWrap(IRow row, int col) } Delegate IColumn.GetGetter() - { - return GetGetter(); - } + => GetGetter(); public ValueGetter GetGetter() - { - return _row.GetGetter(_col); - } + => _row.GetGetter(_col); } /// @@ -269,17 +289,53 @@ public ValueGetter GetGetter() /// private abstract class DefaultCounted : ICounted { - public long Position { get { return 0; } } - public long Batch { get { return 0; } } + public long Position => 0; + public long Batch => 0; public ValueGetter GetIdGetter() + => IdGetter; + + private static void IdGetter(ref UInt128 id) + => id = default; + } + + /// + /// Simple wrapper for a schema column, considered inactive with no getter.
+ /// + /// The type of the getter + private sealed class SchemaWrap : IColumn + { + private readonly ISchema _schema; + private readonly int _col; + private MetadataRow _meta; + + public string Name => _schema.GetColumnName(_col); + public ColumnType Type => _schema.GetColumnType(_col); + public bool IsActive => false; + + public IRow Metadata { - return IdGetter; + get { + if (_meta == null) + Interlocked.CompareExchange(ref _meta, new MetadataRow(_schema, _col), null); + return _meta; + } } - private static void IdGetter(ref UInt128 id) + public SchemaWrap(ISchema schema, int col) { - id = default(UInt128); + Contracts.AssertValue(schema); + Contracts.Assert(0 <= col && col < schema.ColumnCount); + Contracts.Assert(schema.GetColumnType(col).RawType == typeof(T)); + + _schema = schema; + _col = col; } + + Delegate IColumn.GetGetter() + => GetGetter(); + + public ValueGetter GetGetter() + => throw Contracts.Except("Column not active"); } /// @@ -289,7 +345,7 @@ private static void IdGetter(ref UInt128 id) /// private sealed class MetadataRow : DefaultCounted, IRow { - public ISchema Schema { get { return _schema; } } + public ISchema Schema => _schema; private readonly ISchema _metaSchema; private readonly int _col; @@ -379,13 +435,9 @@ public ValueGetter GetGetter(int col) /// private abstract class SimpleColumnBase : IColumn { - private readonly IRow _meta; - private readonly string _name; - private readonly ColumnType _type; - - public string Name { get { return _name; } } - public IRow Metadata { get { return _meta; } } - public ColumnType Type { get { return _type; } } + public string Name { get; } + public IRow Metadata { get; } + public ColumnType Type { get; } public abstract bool IsActive { get; } public SimpleColumnBase(string name, IRow meta, ColumnType type) @@ -395,9 +447,9 @@ public SimpleColumnBase(string name, IRow meta, ColumnType type) Contracts.CheckValue(type, nameof(type)); Contracts.CheckParam(type.RawType == typeof(T), nameof(type), 
"Mismatch between CLR type and column type"); - _name = name; - _meta = meta; - _type = type; + Name = name; + Metadata = meta; + Type = type; } Delegate IColumn.GetGetter() @@ -427,7 +479,7 @@ private sealed class ConstOneImpl : SimpleColumnBase { private readonly T _value; - public override bool IsActive { get { return true; } } + public override bool IsActive => true; public ConstOneImpl(string name, IRow meta, ColumnType type, T value) : base(name, meta, type) @@ -474,7 +526,7 @@ private sealed class GetterImpl : SimpleColumnBase { private readonly ValueGetter _getter; - public override bool IsActive { get { return _getter != null; } } + public override bool IsActive => _getter != null; public GetterImpl(string name, IRow meta, ColumnType type, ValueGetter getter) : base(name, meta, type) @@ -500,9 +552,9 @@ private sealed class RowColumnRow : IRow private readonly IColumn[] _columns; private readonly SchemaImpl _schema; - public ISchema Schema { get { return _schema; } } - public long Position { get { return _counted.Position; } } - public long Batch { get { return _counted.Batch; } } + public ISchema Schema => _schema; + public long Position => _counted.Position; + public long Batch => _counted.Batch; public RowColumnRow(ICounted counted, IColumn[] columns) { @@ -538,7 +590,7 @@ private sealed class SchemaImpl : ISchema private readonly RowColumnRow _parent; private readonly Dictionary _nameToIndex; - public int ColumnCount { get { return _parent._columns.Length; } } + public int ColumnCount => _parent._columns.Length; public SchemaImpl(RowColumnRow parent) { diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeReaderEstimator.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeReaderEstimator.cs index 49f7f8b99d..e30aa97a42 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/CompositeReaderEstimator.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeReaderEstimator.cs @@ -55,5 +55,4 @@ public CompositeReaderEstimator Append(IEstimator return new 
CompositeReaderEstimator(_start, _estimatorChain.Append(estimator)); } } - } diff --git a/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs b/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs index 9cb7ec4dab..ecbc28ebdf 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System; using Microsoft.ML.Core.Data; namespace Microsoft.ML.Runtime.Data @@ -74,5 +75,53 @@ public static TransformerChain Append(this ITransformer start, T return new TransformerChain(start, transformer); } + + private sealed class DelegateEstimator : IEstimator + where TTransformer : class, ITransformer + { + private readonly IEstimator _est; + private readonly Action _onFit; + + public DelegateEstimator(IEstimator estimator, Action onFit) + { + Contracts.AssertValue(estimator); + Contracts.AssertValue(onFit); + _est = estimator; + _onFit = onFit; + } + + public TTransformer Fit(IDataView input) + { + var trans = _est.Fit(input); + _onFit(trans); + return trans; + } + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + => _est.GetOutputSchema(inputSchema); + } + + /// + /// Given an estimator, return a wrapping object that will call a delegate once + /// is called. It is often important for an estimator to return information about what was fit, which is why the + /// method returns a specifically typed object, rather than just a general + /// . However, at the same time, are often formed into pipelines + /// with many objects, so we may need to build a chain of estimators via where the + /// estimator for which we want to get the transformer is buried somewhere in this chain. For that scenario, we can through this + /// method attach a delegate that will be called once fit is called. 
/// + /// The type of returned by + /// The estimator to wrap + /// The delegate that is called with the resulting instances once + /// is called. Because + /// may be called multiple times, this delegate may also be called multiple times. + /// A wrapping estimator that calls the indicated delegate whenever fit is called + public static IEstimator WithOnFitDelegate(this IEstimator estimator, Action onFit) + where TTransformer : class, ITransformer + { + Contracts.CheckValue(estimator, nameof(estimator)); + Contracts.CheckValue(onFit, nameof(onFit)); + return new DelegateEstimator(estimator, onFit); + } } } diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs index dc7cb43830..5ccbb8d7e9 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoader.cs @@ -200,16 +200,35 @@ public sealed class Range { public Range() { } + /// + /// A range representing a single value. Will result in a scalar column. + /// + /// The index of the field of the text file to read. public Range(int index) - : this(index, index) { } + { + Contracts.CheckParam(index >= 0, nameof(index), "Must be non-negative"); + Min = index; + Max = index; + } - public Range(int min, int max) + /// + /// A range representing a set of values. Will result in a vector column. + /// + /// The minimum inclusive index of the column. + /// The maximum-inclusive index of the column. If null + /// indicates that the should auto-detect the length
max) { - Contracts.CheckParam(min >= 0, nameof(min), "min must be non-negative."); - Contracts.CheckParam(max >= min, nameof(max), "max must be greater than or equal to min."); + Contracts.CheckParam(min >= 0, nameof(min), "Must be non-negative"); + Contracts.CheckParam(!(max < min), nameof(max), "If specified, must be greater than or equal to " + nameof(min)); Min = min; Max = max; + // Note that without the following being set, in the case where there is a single range + // where Min == Max, the result will not be a vector valued but a scalar column. + ForceVector = true; + AutoEnd = max == null; } [Argument(ArgumentType.Required, HelpText = "First index in the range")] diff --git a/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderStatic.cs b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderStatic.cs new file mode 100644 index 0000000000..095617deb1 --- /dev/null +++ b/src/Microsoft.ML.Data/DataLoadSave/Text/TextLoaderStatic.cs @@ -0,0 +1,259 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data.StaticPipe; +using Microsoft.ML.Data.StaticPipe.Runtime; + +namespace Microsoft.ML.Runtime.Data +{ + public sealed partial class TextLoader + { + /// + /// Configures a reader for text files. + /// + /// The type shape parameter, which must be + /// + /// The delegate that describes what fields to read from the text file, as well as + /// describing their input type. The way in which it works is that the delegate is fed a , + /// and the user composes a value-tuple with instances out of that . + /// The resulting data will have columns with the names corresponding to their names in the value-tuple. + /// Input files. 
If null then no files are read, but this means that options or + /// configurations that require input data for initialization (e.g., or + /// ) with a null second argument. + /// Data file has header with feature names. + /// Text field separator. + /// Whether the input may include quoted values, which can contain separator + /// characters, colons, and distinguish empty values from missing values. When true, consecutive separators + /// denote a missing value and an empty value is denoted by "". When false, consecutive separators + /// denote an empty value. + /// Whether the input may include sparse representations. + /// Remove trailing whitespace from lines. + /// A configured statically-typed reader for text files. + public static DataReader CreateReader<[IsShape] TTupleShape>( + IHostEnvironment env, Func func, IMultiStreamSource files = null, + bool hasHeader = false, char separator = '\t', bool allowQuoting = true, bool allowSparse = true, + bool trimWhitspace = false) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(func, nameof(func)); + env.CheckValueOrNull(files); + + // Populate all args except the columns. 
+ var args = new Arguments(); + args.AllowQuoting = allowQuoting; + args.AllowSparse = allowSparse; + args.HasHeader = hasHeader; + args.SeparatorChars = new[] { separator }; + args.TrimWhitespace = trimWhitspace; + + var rec = new TextReconciler(args, files); + var ctx = new Context(rec); + + using (var ch = env.Start("Initializing " + nameof(TextLoader))) + { + var readerEst = StaticPipeUtils.ReaderEstimatorAnalyzerHelper(env, ch, ctx, rec, func); + Contracts.AssertValue(readerEst); + var reader = readerEst.Fit(files); + ch.Done(); + return reader; + } + } + + private sealed class TextReconciler : ReaderReconciler + { + private readonly Arguments _args; + private readonly IMultiStreamSource _files; + + public TextReconciler(Arguments args, IMultiStreamSource files) + { + Contracts.AssertValue(args); + Contracts.AssertValueOrNull(files); + + _args = args; + _files = files; + } + + public override IDataReaderEstimator> Reconcile( + IHostEnvironment env, PipelineColumn[] toOutput, IReadOnlyDictionary outputNames) + { + Contracts.AssertValue(env); + Contracts.AssertValue(toOutput); + Contracts.AssertValue(outputNames); + Contracts.Assert(_args.Column == null); + + Column Create(PipelineColumn pipelineCol) + { + var pipelineArgCol = (IPipelineArgColumn)pipelineCol; + var argCol = pipelineArgCol.Create(); + argCol.Name = outputNames[pipelineCol]; + return argCol; + } + + var cols = _args.Column = new Column[toOutput.Length]; + for (int i = 0; i < toOutput.Length; ++i) + cols[i] = Create(toOutput[i]); + + var orig = new TextLoader(env, _args, _files); + return new TrivialReaderEstimator(orig); + } + } + + private interface IPipelineArgColumn + { + /// + /// Creates a object corresponding to the , with everything + /// filled in except . + /// + Column Create(); + } + + /// + /// Context object by which a user can indicate what fields they want to read from a text file, and what data type they ought to have. 
+ /// Instances of this class are never made but the user, but rather are fed into the delegate in + /// . + /// + public sealed class Context + { + private readonly Reconciler _rec; + + internal Context(Reconciler rec) + { + Contracts.AssertValue(rec); + _rec = rec; + } + + /// + /// Reads a scalar Boolean column from a single field in the text file. + /// + /// The zero-based index of the field to read from. + /// The column representation. + public Scalar LoadBool(int ordinal) => Load(DataKind.BL, ordinal); + + /// + /// Reads a vector Boolean column from a range of fields in the text file. + /// + /// The zero-based inclusive lower index of the field to read from. + /// The zero-based inclusive upper index of the field to read from. + /// Note that if this is null, it will read to the end of the line. The file(s) + /// will be inspected to get the length of the type. + /// The column representation. + public Vector LoadBool(int minOrdinal, int? maxOrdinal) => Load(DataKind.BL, minOrdinal, maxOrdinal); + + /// + /// Reads a scalar single-precision floating point column from a single field in the text file. + /// + /// The zero-based index of the field to read from. + /// The column representation. + public Scalar LoadFloat(int ordinal) => Load(DataKind.R4, ordinal); + + /// + /// Reads a vector single-precision column from a range of fields in the text file. + /// + /// The zero-based inclusive lower index of the field to read from. + /// The zero-based inclusive upper index of the field to read from. + /// Note that if this is null, it will read to the end of the line. The file(s) + /// will be inspected to get the length of the type. + /// The column representation. + public Vector LoadFloat(int minOrdinal, int? maxOrdinal) => Load(DataKind.R4, minOrdinal, maxOrdinal); + + /// + /// Reads a scalar double-precision floating point column from a single field in the text file. + /// + /// The zero-based index of the field to read from. 
+ /// The column representation. + public Scalar LoadDouble(int ordinal) => Load(DataKind.R8, ordinal); + + /// + /// Reads a vector double-precision column from a range of fields in the text file. + /// + /// The zero-based inclusive lower index of the field to read from. + /// The zero-based inclusive upper index of the field to read from. + /// Note that if this is null, it will read to the end of the line. The file(s) + /// will be inspected to get the length of the type. + /// The column representation. + public Vector LoadDouble(int minOrdinal, int? maxOrdinal) => Load(DataKind.R8, minOrdinal, maxOrdinal); + + /// + /// Reads a scalar textual column from a single field in the text file. + /// + /// The zero-based index of the field to read from. + /// The column representation. + public Scalar LoadText(int ordinal) => Load(DataKind.TX, ordinal); + + /// + /// Reads a vector textual column from a range of fields in the text file. + /// + /// The zero-based inclusive lower index of the field to read from. + /// The zero-based inclusive upper index of the field to read from. + /// Note that if this is null, it will read to the end of the line. The file(s) + /// will be inspected to get the length of the type. + /// The column representation. + public Vector LoadText(int minOrdinal, int? maxOrdinal) => Load(DataKind.TX, minOrdinal, maxOrdinal); + + private Scalar Load(DataKind kind, int ordinal) + { + Contracts.CheckParam(ordinal >= 0, nameof(ordinal), "Should be non-negative"); + return new MyScalar(_rec, kind, ordinal); + } + + private Vector Load(DataKind kind, int minOrdinal, int? 
maxOrdinal) + { + Contracts.CheckParam(minOrdinal >= 0, nameof(minOrdinal), "Should be non-negative"); + var v = maxOrdinal >= minOrdinal; + Contracts.CheckParam(!(maxOrdinal < minOrdinal), nameof(maxOrdinal), "If specified, cannot be less than " + nameof(minOrdinal)); + return new MyVector(_rec, kind, minOrdinal, maxOrdinal); + } + + private class MyScalar : Scalar, IPipelineArgColumn + { + private readonly DataKind _kind; + private readonly int _ordinal; + + public MyScalar(Reconciler rec, DataKind kind, int ordinal) + : base(rec, null) + { + _kind = kind; + _ordinal = ordinal; + } + + public Column Create() + { + return new Column() + { + Type = _kind, + Source = new[] { new Range(_ordinal) }, + }; + } + } + + private class MyVector : Vector, IPipelineArgColumn + { + private readonly DataKind _kind; + private readonly int _min; + private readonly int? _max; + + public MyVector(Reconciler rec, DataKind kind, int min, int? max) + : base(rec, null) + { + _kind = kind; + _min = min; + _max = max; + } + + public Column Create() + { + return new Column() + { + Type = _kind, + Source = new[] { new Range(_min, _max) }, + }; + } + } + } + } +} + diff --git a/src/Microsoft.ML.Data/DataLoadSave/TrivialReaderEstimator.cs b/src/Microsoft.ML.Data/DataLoadSave/TrivialReaderEstimator.cs index 3bb589191f..506ea8cf73 100644 --- a/src/Microsoft.ML.Data/DataLoadSave/TrivialReaderEstimator.cs +++ b/src/Microsoft.ML.Data/DataLoadSave/TrivialReaderEstimator.cs @@ -12,15 +12,15 @@ namespace Microsoft.ML.Runtime.Data public sealed class TrivialReaderEstimator: IDataReaderEstimator where TReader: IDataReader { - private readonly TReader _reader; + public TReader Reader { get; } public TrivialReaderEstimator(TReader reader) { - _reader = reader; + Reader = reader; } - public TReader Fit(TSource input) => _reader; + public TReader Fit(TSource input) => Reader; - public SchemaShape GetOutputSchema() => SchemaShape.Create(_reader.GetOutputSchema()); + public SchemaShape GetOutputSchema() => 
SchemaShape.Create(Reader.GetOutputSchema()); } } diff --git a/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj b/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj index 8d5b0fd2d0..2a23a7322c 100644 --- a/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj +++ b/src/Microsoft.ML.Data/Microsoft.ML.Data.csproj @@ -1,4 +1,4 @@ - + netstandard2.0 diff --git a/src/Microsoft.ML.Data/StaticPipe/Attributes.cs b/src/Microsoft.ML.Data/StaticPipe/Attributes.cs new file mode 100644 index 0000000000..6a02b98d1a --- /dev/null +++ b/src/Microsoft.ML.Data/StaticPipe/Attributes.cs @@ -0,0 +1,27 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; + +namespace Microsoft.ML.Data.StaticPipe +{ + /// + /// An indicator to the analyzer that this type parameter ought to be a valid schema-shape object (e.g., a leaf-tuple, or + /// value-tuples of such) as the return type. Note that this attribute is typically only used in situations where a user + /// might be essentially declaring that type, as opposed to using an already established shape type. So: a method that merely + /// takes an already existing typed instance would tend on the other hand to not use this type parameter. To give an example: + /// + /// has the parameter on the new output tuple shape. + /// + /// The cost to not specifying this on such an entry point is that the compile time type-checks on the shape parameters will + /// no longer be enforced, which is suboptimal given that the purpose of the statically typed interfaces is to have compile-time + /// checks. However, it is not disastrous since the runtime checks will still be in effect. + /// + /// User code may use this attribute on their types if they have generic type parameters that interface with this library. 
+ /// + [AttributeUsage(AttributeTargets.GenericParameter)] + public sealed class IsShapeAttribute : Attribute + { + } +} diff --git a/src/Microsoft.ML.Data/StaticPipe/DataReader.cs b/src/Microsoft.ML.Data/StaticPipe/DataReader.cs new file mode 100644 index 0000000000..3cf9866509 --- /dev/null +++ b/src/Microsoft.ML.Data/StaticPipe/DataReader.cs @@ -0,0 +1,56 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data.StaticPipe.Runtime; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; + +namespace Microsoft.ML.Data.StaticPipe +{ + public sealed class DataReader : SchemaBearing + { + public IDataReader AsDynamic { get; } + + internal DataReader(IHostEnvironment env, IDataReader reader, StaticSchemaShape shape) + : base(env, shape) + { + Env.AssertValue(reader); + + AsDynamic = reader; + Shape.Check(Env, AsDynamic.GetOutputSchema()); + } + + public DataReaderEstimator> Append(Estimator estimator) + where TTrans : class, ITransformer + { + Contracts.Assert(nameof(Append) == nameof(CompositeReaderEstimator.Append)); + + var readerEst = AsDynamic.Append(estimator.AsDynamic); + return new DataReaderEstimator>(Env, readerEst, estimator.Shape); + } + + public DataReader Append(Transformer transformer) + where TTransformer : class, ITransformer + { + Env.CheckValue(transformer, nameof(transformer)); + Env.Assert(nameof(Append) == nameof(CompositeReaderEstimator.Append)); + + var reader = AsDynamic.Append(transformer.AsDynamic); + return new DataReader(Env, reader, transformer.Shape); + } + + public DataView Read(TIn input) + { + // We cannot check the value of input since it may not be a reference type, and it is not clear + // that there is an absolute case for insisting that the input type be a reference type, and much + // less further that null inputs 
will never be correct. So we rely on the wrapping object to make + // that determination. + Env.Assert(nameof(Read) == nameof(IDataReader.Read)); + + var data = AsDynamic.Read(input); + return new DataView(Env, data, Shape); + } + } +} diff --git a/src/Microsoft.ML.Data/StaticPipe/DataReaderEstimator.cs b/src/Microsoft.ML.Data/StaticPipe/DataReaderEstimator.cs new file mode 100644 index 0000000000..d922c2a677 --- /dev/null +++ b/src/Microsoft.ML.Data/StaticPipe/DataReaderEstimator.cs @@ -0,0 +1,43 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data.StaticPipe.Runtime; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; + +namespace Microsoft.ML.Data.StaticPipe +{ + public sealed class DataReaderEstimator : SchemaBearing + where TDataReader : class, IDataReader + { + public IDataReaderEstimator AsDynamic { get; } + + internal DataReaderEstimator(IHostEnvironment env, IDataReaderEstimator estimator, StaticSchemaShape shape) + : base(env, shape) + { + Env.AssertValue(estimator); + + AsDynamic = estimator; + Shape.Check(Env, AsDynamic.GetOutputSchema()); + } + + public DataReader Fit(TIn input) + { + Contracts.Assert(nameof(Fit) == nameof(IDataReaderEstimator.Fit)); + + var reader = AsDynamic.Fit(input); + return new DataReader(Env, reader, Shape); + } + + public DataReaderEstimator> Append(Estimator est) + where TTrans : class, ITransformer + { + Contracts.Assert(nameof(Append) == nameof(CompositeReaderEstimator.Append)); + + var readerEst = AsDynamic.Append(est.AsDynamic); + return new DataReaderEstimator>(Env, readerEst, est.Shape); + } + } +} diff --git a/src/Microsoft.ML.Data/StaticPipe/DataView.cs b/src/Microsoft.ML.Data/StaticPipe/DataView.cs new file mode 100644 index 0000000000..9c2f3f22c5 --- /dev/null +++ 
b/src/Microsoft.ML.Data/StaticPipe/DataView.cs @@ -0,0 +1,24 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Data.StaticPipe.Runtime; + +namespace Microsoft.ML.Data.StaticPipe +{ + public class DataView : SchemaBearing + { + public IDataView AsDynamic { get; } + + internal DataView(IHostEnvironment env, IDataView view, StaticSchemaShape shape) + : base(env, shape) + { + Env.AssertValue(view); + + AsDynamic = view; + Shape.Check(Env, AsDynamic.Schema); + } + } +} diff --git a/src/Microsoft.ML.Data/StaticPipe/Estimator.cs b/src/Microsoft.ML.Data/StaticPipe/Estimator.cs new file mode 100644 index 0000000000..28e79712b5 --- /dev/null +++ b/src/Microsoft.ML.Data/StaticPipe/Estimator.cs @@ -0,0 +1,98 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data.StaticPipe.Runtime; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; + +namespace Microsoft.ML.Data.StaticPipe +{ + public sealed class Estimator : SchemaBearing + where TTransformer : class, ITransformer + { + public IEstimator AsDynamic { get; } + private readonly StaticSchemaShape _inShape; + + internal Estimator(IHostEnvironment env, IEstimator estimator, StaticSchemaShape inShape, StaticSchemaShape outShape) + : base(env, outShape) + { + Env.CheckValue(estimator, nameof(estimator)); + AsDynamic = estimator; + _inShape = inShape; + // Our ability to check estimators at constructor time is somewhat limited. During fit though we could. 
+ // Fortunately, estimators are one of the least likely things that users will frequently declare the + // types of on their own. + } + + public Transformer Fit(DataView view) + { + Contracts.Assert(nameof(Fit) == nameof(IEstimator.Fit)); + _inShape.Check(Env, view.AsDynamic.Schema); + + var trans = AsDynamic.Fit(view.AsDynamic); + return new Transformer(Env, trans, _inShape, Shape); + } + + public Estimator Append(Estimator estimator) + { + Env.CheckValue(estimator, nameof(estimator)); + + var est = AsDynamic.Append(estimator.AsDynamic); + return new Estimator(Env, est, _inShape, estimator.Shape); + } + + public Estimator Append<[IsShape] TTupleNewOutShape>(Func mapper) + { + Contracts.CheckValue(mapper, nameof(mapper)); + + using (var ch = Env.Start(nameof(Append))) + { + var method = mapper.Method; + + // Construct the dummy column structure, then apply the mapping. + var input = StaticPipeInternalUtils.MakeAnalysisInstance(out var fakeReconciler); + KeyValuePair[] inPairs = StaticPipeInternalUtils.GetNamesValues(input, method.GetParameters()[0]); + + // Initially we suppose we've only assigned names to the inputs. + var inputColToName = new Dictionary(); + foreach (var p in inPairs) + inputColToName[p.Value] = p.Key; + string NameMap(PipelineColumn col) + { + inputColToName.TryGetValue(col, out var val); + return val; + } + + var readerEst = StaticPipeUtils.GeneralFunctionAnalyzer(Env, ch, input, fakeReconciler, mapper, out var estTail, NameMap); + ch.Assert(readerEst == null); + ch.AssertValue(estTail); + + var est = AsDynamic.Append(estTail); + var newOut = StaticSchemaShape.Make(method.ReturnParameter); + var toReturn = new Estimator(Env, est, _inShape, newOut); + ch.Done(); + return toReturn; + } + } + } + + public static class Estimator + { + /// + /// Create an object that can be used as the start of a new pipeline, that assumes it uses + /// something with the shape of as its input schema shape. + /// The returned object is an empty estimator. 
+ /// + /// Creates a new empty head of a pipeline + /// The empty estimator, to which new items may be appended to create a pipeline + public static Estimator MakeNew(SchemaBearing fromSchema) + { + Contracts.CheckValue(fromSchema, nameof(fromSchema)); + return fromSchema.MakeNewEstimator(); + } + } +} diff --git a/src/Microsoft.ML.Data/StaticPipe/PipelineColumn.cs b/src/Microsoft.ML.Data/StaticPipe/PipelineColumn.cs new file mode 100644 index 0000000000..dd8ace66b3 --- /dev/null +++ b/src/Microsoft.ML.Data/StaticPipe/PipelineColumn.cs @@ -0,0 +1,144 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Core.Data; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; + +namespace Microsoft.ML.Data.StaticPipe.Runtime +{ + /// + /// This class is used as a type marker for producing structures for use in the statically + /// typed columnate pipeline building helper API. Users will not create these structures directly. Rather components + /// will implement (hidden) subclasses of one of this class's subclasses (e.g., , + /// ), which will contain information that the builder API can use to construct an actual + /// sequence of objects. + /// + public abstract class PipelineColumn + { + internal readonly Reconciler ReconcilerObj; + internal readonly PipelineColumn[] Dependencies; + + private protected PipelineColumn(Reconciler reconciler, PipelineColumn[] dependencies) + { + Contracts.CheckValue(reconciler, nameof(reconciler)); + Contracts.CheckValueOrNull(dependencies); + + ReconcilerObj = reconciler; + Dependencies = dependencies; + } + } + + /// + /// For representing a non-key, non-vector . 
+ /// + /// + public abstract class Scalar : PipelineColumn + { + protected Scalar(Reconciler reconciler, params PipelineColumn[] dependencies) + : base(reconciler, dependencies) + { + } + + public override string ToString() => $"{nameof(Scalar)}<{typeof(T).Name}>"; + } + + /// + /// For representing a of known length. + /// + /// The vector item type. + public abstract class Vector : PipelineColumn + { + protected Vector(Reconciler reconciler, params PipelineColumn[] dependencies) + : base(reconciler, dependencies) + { + } + + public override string ToString() => $"{nameof(Vector)}<{typeof(T).Name}>"; + } + + /// + /// For representing a that is normalized, that is, its + /// value is set with the value true. + /// + /// The vector item type. + public abstract class NormVector : Vector + { + protected NormVector(Reconciler reconciler, params PipelineColumn[] dependencies) + : base(reconciler, dependencies) + { + } + + public override string ToString() => $"{nameof(NormVector)}<{typeof(T).Name}>"; + } + + /// + /// For representing a of unknown length. + /// + /// The vector item type. + public abstract class VarVector : PipelineColumn + { + protected VarVector(Reconciler reconciler, params PipelineColumn[] dependencies) + : base(reconciler, dependencies) + { + } + + public override string ToString() => $"{nameof(VarVector)}<{typeof(T).Name}>"; + } + + /// + /// For representing a of known cardinality, where the type of key is not specified. + /// + /// The physical type representing the key, which should always be one of , + /// , , or + /// Note that a vector of keys type we would represent as with a + /// type parameter. Note also, if the type of the key is known then that should be represented + /// by . 
+ public abstract class Key : PipelineColumn + { + protected Key(Reconciler reconciler, params PipelineColumn[] dependencies) + : base(reconciler, dependencies) + { + } + + public override string ToString() => $"{nameof(Key)}<{typeof(T).Name}>"; + } + + /// + /// For representing a key-type of known cardinality that has key values over a particular type. This is used to + /// represent a where it is known that it will have of a particular type . + /// + /// The physical type representing the key, which should always be one of , + /// , , or + /// The type of values the key-type is enumerating. Commonly this is but + /// this is not necessary + public abstract class Key : Key + { + protected Key(Reconciler reconciler, params PipelineColumn[] dependencies) + : base(reconciler, dependencies) + { + } + + public override string ToString() => $"{nameof(Key)}<{typeof(T).Name}, {typeof(TVal).Name}>"; + } + + /// + /// For representing a of unknown cardinality. + /// + /// The physical type representing the key, which should always be one of , + /// , , or + /// Note that unlike the and duality, there is no + /// type corresponding to this type but with key-values, since key-values are necessarily a vector of known + /// size so any enumeration into that set would itself be a key-value of unknown cardinality. + public abstract class VarKey : PipelineColumn + { + protected VarKey(Reconciler reconciler, params PipelineColumn[] dependencies) + : base(reconciler, dependencies) + { + } + + public override string ToString() => $"{nameof(VarKey)}<{typeof(T).Name}>"; + } +} diff --git a/src/Microsoft.ML.Data/StaticPipe/Reconciler.cs b/src/Microsoft.ML.Data/StaticPipe/Reconciler.cs new file mode 100644 index 0000000000..dac94f7a01 --- /dev/null +++ b/src/Microsoft.ML.Data/StaticPipe/Reconciler.cs @@ -0,0 +1,78 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; + +namespace Microsoft.ML.Data.StaticPipe.Runtime +{ + /// + /// An object for instances to indicate to the analysis code for static pipelines that + /// they should be considered a single group of columns (through equality on the reconcilers), as well as how to + /// actually create the underlying dynamic structures, whether an + /// (for the ) or a + /// (for the ). + /// + public abstract class Reconciler + { + private protected Reconciler() { } + } + + /// + /// Reconciler for column groups intended to resolve to a new + /// or . + /// + /// The input type of the + /// object. + public abstract class ReaderReconciler : Reconciler + { + public ReaderReconciler() : base() { } + + /// + /// Returns a data-reader estimator. Note that there are no input names because the columns from a data-reader + /// estimator should have no dependencies. + /// + /// The host environment to use to create the data-reader estimator + /// The columns that the object created by the reconciler should output + /// A map containing + /// + public abstract IDataReaderEstimator> Reconcile( + IHostEnvironment env, PipelineColumn[] toOutput, IReadOnlyDictionary outputNames); + } + + /// + /// Reconciler for column groups intended to resolve to an . This type of + /// reconciler will work with + /// or other methods that involve the creation of estimator chains. + /// + public abstract class EstimatorReconciler : Reconciler + { + public EstimatorReconciler() : base() { } + + /// + /// Returns an estimator. 
+ /// + /// The host environment to use to create the estimator + /// The columns that the object created by the reconciler should output + /// The name mapping that maps dependencies of the output columns to their names + /// The name mapping that maps the output column to their names + /// While most estimators allow full control over the names of their outputs, a limited + /// subset of estimator transforms do not allow this: they produce columns whose names are unconfigurable. For + /// these, there is this collection which provides the names used by the analysis tool. If the estimator under + /// construction must use one of the names here, then they are responsible for "saving" the column they will + /// overwrite using applications of the . Note that if the estimator under + /// construction has complete control over what columns it produces, there is no need for it to pay this argument + /// any attention. + /// Returns an estimator. + public abstract IEstimator Reconcile( + IHostEnvironment env, + PipelineColumn[] toOutput, + IReadOnlyDictionary inputNames, + IReadOnlyDictionary outputNames, + IReadOnlyCollection usedNames); + } +} diff --git a/src/Microsoft.ML.Data/StaticPipe/SchemaAssertionContext.cs b/src/Microsoft.ML.Data/StaticPipe/SchemaAssertionContext.cs new file mode 100644 index 0000000000..5fb6babbec --- /dev/null +++ b/src/Microsoft.ML.Data/StaticPipe/SchemaAssertionContext.cs @@ -0,0 +1,215 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Data; + +namespace Microsoft.ML.Data.StaticPipe.Runtime +{ + /// + /// An object for declaring a schema-shape. This is mostly commonly used in situations where a user is + /// asserting that a dynamic object bears a certain specific static schema. For example: when phrasing + /// the dynamically typed as being a specific . 
+ /// It is never created by the user directly, but instead an instance is typically fed in as an argument + /// to a delegate, and the user will call methods on this context to indicate a certain type is so. + /// + /// + /// All objects are, deliberately, imperatively useless as they are + /// intended to be used only in a declarative fashion. The methods and properties of this class go one step + /// further and return null for everything with a return type of . + /// + /// Because 's type system is extensible, assemblies that declare their own types + /// should allow users to assert typedness in their types by defining extension methods over this class. + /// However, even failing the provision of such a helper, a user can still provide a workaround by just + /// declaring the type as something like default(Scalar<TheCustomType>), without using the + /// instance of this context. + /// + public sealed class SchemaAssertionContext + { + // Hiding all these behind empty-structures is a bit of a cheap trick, but probably works + // pretty well considering that the alternative is a bunch of tiny objects allocated on the + // stack. Plus, the default value winds up working for them. We can also exploit the `ref struct` + // property of these things to make sure people don't make the mistake of assigning them as the + // values. + + /// Assertions over a column of . + public PrimitiveTypeAssertions I1 => default; + + /// Assertions over a column of . + public PrimitiveTypeAssertions I2 => default; + + /// Assertions over a column of . + public PrimitiveTypeAssertions I4 => default; + + /// Assertions over a column of . + public PrimitiveTypeAssertions I8 => default; + + /// Assertions over a column of . + public PrimitiveTypeAssertions U1 => default; + + /// Assertions over a column of . + public PrimitiveTypeAssertions U2 => default; + + /// Assertions over a column of . + public PrimitiveTypeAssertions U4 => default; + + /// Assertions over a column of . 
+ public PrimitiveTypeAssertions U8 => default; + + /// Assertions over a column of . + public NormalizableTypeAssertions R4 => default; + + /// Assertions over a column of . + public NormalizableTypeAssertions R8 => default; + + /// Assertions over a column of . + public PrimitiveTypeAssertions Text => default; + + /// Assertions over a column of . + public PrimitiveTypeAssertions Bool => default; + + /// Assertions over a column of with . + public KeyTypeSelectorAssertions KeyU1 => default; + /// Assertions over a column of with . + public KeyTypeSelectorAssertions KeyU2 => default; + /// Assertions over a column of with . + public KeyTypeSelectorAssertions KeyU4 => default; + /// Assertions over a column of with . + public KeyTypeSelectorAssertions KeyU8 => default; + + internal static SchemaAssertionContext Inst = new SchemaAssertionContext(); + + private SchemaAssertionContext() { } + + // Until we have some transforms that use them, we might not expect to see too much interest in asserting + // the time relevant datatypes. + + /// + /// Holds assertions relating to the basic primitive types. + /// + public ref struct PrimitiveTypeAssertions + { + private PrimitiveTypeAssertions(int i) { } + + /// + /// Asserts a type that is directly this . + /// + public Scalar Scalar => null; + + /// + /// Asserts a type corresponding to a of this , + /// where is true. + /// + public Vector Vector => null; + + /// + /// Asserts a type corresponding to a of this , + /// where is true. + /// + public VarVector VarVector => null; + } + + public ref struct NormalizableTypeAssertions + { + private NormalizableTypeAssertions(int i) { } + + /// + /// Asserts a type that is directly this . + /// + public Scalar Scalar => null; + + /// + /// Asserts a type corresponding to a of this , + /// where is true. + /// + public Vector Vector => null; + + /// + /// Asserts a type corresponding to a of this , + /// where is true. 
+ /// + public VarVector VarVector => null; + /// + /// Asserts a type corresponding to a of this , + /// where is true, and the + /// metadata is defined with a Boolean true value. + /// + public NormVector NormVector => null; + } + + /// + /// Once a single general key type has been selected, we can select its vector-ness. + /// + /// The static type corresponding to a . + public ref struct KeyTypeVectorAssertions + where T : class + { + private KeyTypeVectorAssertions(int i) { } + + /// + /// Asserts a type that is directly this . + /// + public T Scalar => null; + + /// + /// Asserts a type corresponding to a of this , + /// where is true. + /// + public Vector Vector => null; + + /// + /// Asserts a type corresponding to a of this , + /// where is true. + /// + public VarVector VarVector => null; + } + + /// + /// Assertions for key types of various forms. Used to select a particular . + /// + /// + public ref struct KeyTypeSelectorAssertions + { + private KeyTypeSelectorAssertions(int i) { } + + /// + /// Asserts a type corresponding to a where is positive, that is, is of known cardinality, + /// but that we are not asserting has any particular type of metadata. + /// + public KeyTypeVectorAssertions> NoValue => default; + + /// + /// Asserts a type corresponding to a where is zero, that is, is of unknown cardinality. + /// + public KeyTypeVectorAssertions> UnknownCardinality => default; + + /// Asserts a of known cardinality with a vector of metadata. + public KeyTypeVectorAssertions> I1Values => default; + /// Asserts a of known cardinality with a vector of metadata. + public KeyTypeVectorAssertions> I2Values => default; + /// Asserts a of known cardinality with a vector of metadata. + public KeyTypeVectorAssertions> I4Values => default; + /// Asserts a of known cardinality with a vector of metadata. + public KeyTypeVectorAssertions> I8Values => default; + + /// Asserts a of known cardinality with a vector of metadata. 
+ public KeyTypeVectorAssertions> U1Values => default; + /// Asserts a of known cardinality with a vector of metadata. + public KeyTypeVectorAssertions> U2Values => default; + /// Asserts a of known cardinality with a vector of metadata. + public KeyTypeVectorAssertions> U4Values => default; + /// Asserts a of known cardinality with a vector of metadata. + public KeyTypeVectorAssertions> U8Values => default; + + /// Asserts a of known cardinality with a vector of metadata. + public KeyTypeVectorAssertions> R4Values => default; + /// Asserts a of known cardinality with a vector of metadata. + public KeyTypeVectorAssertions> R8Values => default; + + /// Asserts a of known cardinality with a vector of metadata. + public KeyTypeVectorAssertions> TextValues => default; + /// Asserts a of known cardinality with a vector of metadata. + public KeyTypeVectorAssertions> BoolValues => default; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.ML.Data/StaticPipe/SchemaBearing.cs b/src/Microsoft.ML.Data/StaticPipe/SchemaBearing.cs new file mode 100644 index 0000000000..7413ab4764 --- /dev/null +++ b/src/Microsoft.ML.Data/StaticPipe/SchemaBearing.cs @@ -0,0 +1,50 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data.StaticPipe.Runtime; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; + +namespace Microsoft.ML.Data.StaticPipe +{ + /// + /// A base class for the statically-typed pipeline components, that are marked as producing + /// data whose schema has a certain shape. + /// + /// + public abstract class SchemaBearing + { + private protected readonly IHostEnvironment Env; + internal readonly StaticSchemaShape Shape; + + /// + /// Constructor for a block maker. 
+ /// + /// The host environment, stored with this object + /// The item holding the name and types as enumerated within + /// + private protected SchemaBearing(IHostEnvironment env, StaticSchemaShape shape) + { + Contracts.AssertValue(env); + env.AssertValue(shape); + + Env = env; + Shape = shape; + } + + /// + /// Create an object that can be used as the start of a new pipeline, that assumes it uses + /// something with the shape of as its input schema shape. + /// The returned object is an empty estimator. + /// + internal Estimator MakeNewEstimator() + { + var est = new EstimatorChain(); + return new Estimator(Env, est, Shape, Shape); + } + } +} diff --git a/src/Microsoft.ML.Data/StaticPipe/StaticPipeExtensions.cs b/src/Microsoft.ML.Data/StaticPipe/StaticPipeExtensions.cs new file mode 100644 index 0000000000..d14e0f857a --- /dev/null +++ b/src/Microsoft.ML.Data/StaticPipe/StaticPipeExtensions.cs @@ -0,0 +1,98 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Data; +using System; +using Microsoft.ML.Data.StaticPipe.Runtime; +using Microsoft.ML.Runtime; +using Microsoft.ML.Core.Data; + +namespace Microsoft.ML.Data.StaticPipe +{ + public static class StaticPipeExtensions + { + /// + /// Asserts that a given data view has the indicated schema. If this method returns without + /// throwing then the view has been validated to have columns with the indicated names and types. 
+ /// + /// The type representing the view's schema shape + /// The view to assert the static schema on + /// The host environment to keep in the statically typed variant + /// The delegate through which we declare the schema, which ought to + /// use the input to declare a + /// of the indices, properly named + /// A statically typed wrapping of the input view + public static DataView AssertStatic<[IsShape] T>(this IDataView view, IHostEnvironment env, + Func outputDecl) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(view, nameof(view)); + env.CheckValue(outputDecl, nameof(outputDecl)); + + // We don't actually need to call the method, it's just there to give the declaration. +#if DEBUG + outputDecl(SchemaAssertionContext.Inst); +#endif + + var schema = StaticSchemaShape.Make(outputDecl.Method.ReturnParameter); + return new DataView(env, view, schema); + } + + public static DataReader AssertStatic(this IDataReader reader, IHostEnvironment env, + Func outputDecl) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(reader, nameof(reader)); + env.CheckValue(outputDecl, nameof(outputDecl)); + + var schema = StaticSchemaShape.Make(outputDecl.Method.ReturnParameter); + return new DataReader(env, reader, schema); + } + + public static DataReaderEstimator AssertStatic( + this IDataReaderEstimator readerEstimator, IHostEnvironment env, + Func outputDecl) + where TReader : class, IDataReader + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(readerEstimator, nameof(readerEstimator)); + env.CheckValue(outputDecl, nameof(outputDecl)); + + var schema = StaticSchemaShape.Make(outputDecl.Method.ReturnParameter); + return new DataReaderEstimator(env, readerEstimator, schema); + } + + public static Transformer AssertStatic<[IsShape] TIn, [IsShape] TOut, TTrans>( + this TTrans transformer, IHostEnvironment env, + Func inputDecl, + Func outputDecl) + where TTrans : class, ITransformer + { + Contracts.CheckValue(env, nameof(env)); + 
env.CheckValue(transformer, nameof(transformer)); + env.CheckValue(inputDecl, nameof(inputDecl)); + env.CheckValue(outputDecl, nameof(outputDecl)); + + var inSchema = StaticSchemaShape.Make(inputDecl.Method.ReturnParameter); + var outSchema = StaticSchemaShape.Make(outputDecl.Method.ReturnParameter); + return new Transformer(env, transformer, inSchema, outSchema); + } + + public static Estimator AssertStatic<[IsShape] TIn, [IsShape] TOut, TTrans>( + this IEstimator estimator, IHostEnvironment env, + Func inputDecl, + Func outputDecl) + where TTrans : class, ITransformer + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(estimator, nameof(estimator)); + env.CheckValue(inputDecl, nameof(inputDecl)); + env.CheckValue(outputDecl, nameof(outputDecl)); + + var inSchema = StaticSchemaShape.Make(inputDecl.Method.ReturnParameter); + var outSchema = StaticSchemaShape.Make(outputDecl.Method.ReturnParameter); + return new Estimator(env, estimator, inSchema, outSchema); + } + } +} diff --git a/src/Microsoft.ML.Data/StaticPipe/StaticPipeInternalUtils.cs b/src/Microsoft.ML.Data/StaticPipe/StaticPipeInternalUtils.cs new file mode 100644 index 0000000000..54119f0d0a --- /dev/null +++ b/src/Microsoft.ML.Data/StaticPipe/StaticPipeInternalUtils.cs @@ -0,0 +1,486 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Internal.Utilities; + +namespace Microsoft.ML.Data.StaticPipe.Runtime +{ + /// + /// Utility functions useful for the internal implementations of the key pipeline utilities. 
+ /// + internal static class StaticPipeInternalUtils + { + /// + /// Given a type which is a tree with leaves, return an instance of that + /// type which has appropriate instances of that use the returned reconciler. + /// + /// This is a data-reconciler that always reconciles to a null object + /// A type of either or one of the major subclasses + /// (e.g., , , etc.) + /// An instance of where all fields have the provided reconciler + public static T MakeAnalysisInstance(out ReaderReconciler fakeReconciler) + { + var rec = new AnalyzeUtil.Rec(); + fakeReconciler = rec; + return (T)AnalyzeUtil.MakeAnalysisInstanceCore(rec); + } + + private static class AnalyzeUtil + { + public sealed class Rec : ReaderReconciler + { + public Rec() : base() { } + + public override IDataReaderEstimator> Reconcile( + IHostEnvironment env, PipelineColumn[] toOutput, IReadOnlyDictionary outputNames) + { + Contracts.AssertValue(env); + foreach (var col in toOutput) + env.Assert(col.ReconcilerObj == this); + return null; + } + } + + private static Reconciler _reconciler = new Rec(); + + private sealed class AScalar : Scalar { public AScalar(Rec rec) : base(rec, null) { } } + private sealed class AVector : Vector { public AVector(Rec rec) : base(rec, null) { } } + private sealed class ANormVector : NormVector { public ANormVector(Rec rec) : base(rec, null) { } } + private sealed class AVarVector : VarVector { public AVarVector(Rec rec) : base(rec, null) { } } + private sealed class AKey : Key { public AKey(Rec rec) : base(rec, null) { } } + private sealed class AKey : Key { public AKey(Rec rec) : base(rec, null) { } } + private sealed class AVarKey : VarKey { public AVarKey(Rec rec) : base(rec, null) { } } + + private static PipelineColumn MakeScalar(Rec rec) => new AScalar(rec); + private static PipelineColumn MakeVector(Rec rec) => new AVector(rec); + private static PipelineColumn MakeNormVector(Rec rec) => new ANormVector(rec); + private static PipelineColumn MakeVarVector(Rec rec) 
=> new AVarVector(rec); + private static PipelineColumn MakeKey(Rec rec) => new AKey(rec); + private static Key MakeKey(Rec rec) => new AKey(rec); + private static PipelineColumn MakeVarKey(Rec rec) => new AVarKey(rec); + + private static MethodInfo[] _valueTupleCreateMethod = InitValueTupleCreateMethods(); + + private static MethodInfo[] InitValueTupleCreateMethods() + { + const string methodName = nameof(ValueTuple.Create); + var methods = typeof(ValueTuple).GetMethods() + .Where(m => m.Name == methodName && m.ContainsGenericParameters) + .OrderBy(m => m.GetGenericArguments().Length).Take(7) + .Append(typeof(AnalyzeUtil).GetMethod(nameof(UnstructedCreate))).ToArray(); + return methods; + } + + /// + /// Note that we use this instead of + /// for the eight-item because that method will embed the last element into a one-element tuple, + /// which is embedded in the original. The actual physical representation, which is what is relevant here, + /// has no real conveniences around its creation. 
+ /// + public static ValueTuple + UnstructedCreate( + T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, TRest restTuple) + where TRest : struct + { + return new ValueTuple(v1, v2, v3, v4, v5, v6, v7, restTuple); + } + + public static object MakeAnalysisInstanceCore(Rec rec) + { + var t = typeof(T); + if (typeof(PipelineColumn).IsAssignableFrom(t)) + { + if (t.IsGenericType) + { + var genP = t.GetGenericArguments(); + var genT = t.GetGenericTypeDefinition(); + + if (genT == typeof(Scalar<>)) + return Utils.MarshalInvoke(MakeScalar, genP[0], rec); + if (genT == typeof(Vector<>)) + return Utils.MarshalInvoke(MakeVector, genP[0], rec); + if (genT == typeof(NormVector<>)) + return Utils.MarshalInvoke(MakeNormVector, genP[0], rec); + if (genT == typeof(VarVector<>)) + return Utils.MarshalInvoke(MakeVarVector, genP[0], rec); + if (genT == typeof(Key<>)) + return Utils.MarshalInvoke(MakeKey, genP[0], rec); + if (genT == typeof(Key<,>)) + { + Func f = MakeKey; + return f.Method.GetGenericMethodDefinition().MakeGenericMethod(genP).Invoke(null, new object[] { rec }); + } + if (genT == typeof(VarKey<>)) + return Utils.MarshalInvoke(MakeVector, genP[0], rec); + } + throw Contracts.Except($"Type {t} is a {nameof(PipelineColumn)} yet does not appear to be directly one of " + + $"the official types. This is commonly due to a mistake by the component author and can be addressed by " + + $"upcasting the instance in the tuple definition to one of the official types."); + } + // If it's not a pipeline column type then we suppose it is a value tuple. + + if (t.IsGenericType && ValueTupleUtils.IsValueTuple(t)) + { + var genT = t.GetGenericTypeDefinition(); + var genP = t.GetGenericArguments(); + if (1 <= genP.Length && genP.Length <= 8) + { + // First recursively create the sub-analysis objects. + object[] subArgs = genP.Select(subType => Utils.MarshalInvoke(MakeAnalysisInstanceCore, subType, rec)).ToArray(); + // Next create the tuple. 
+ return _valueTupleCreateMethod[subArgs.Length - 1].MakeGenericMethod(genP).Invoke(null, subArgs); + } + } + throw Contracts.Except($"Type {t} is neither a {nameof(PipelineColumn)} subclass nor a value tuple. Other types are not permitted."); + } + } + + public static KeyValuePair[] GetNamesTypes(ParameterInfo pInfo) + => GetNamesTypes(pInfo); + + public static KeyValuePair[] GetNamesTypes(ParameterInfo pInfo) + { + Contracts.CheckValue(pInfo, nameof(pInfo)); + if (typeof(T) != pInfo.ParameterType) + throw Contracts.ExceptParam(nameof(pInfo), "Type mismatch with " + typeof(T).Name); + var result = NameUtil.GetNames(default, pInfo); + var retVal = new KeyValuePair[result.Length]; + for (int i = 0; i < result.Length; ++i) + { + retVal[i] = new KeyValuePair(result[i].name, result[i].type); + Contracts.Assert(result[i].value == default); + } + return retVal; + } + + public static KeyValuePair[] GetNamesValues(T record, ParameterInfo pInfo) + => GetNamesValues(record, pInfo); + + private static KeyValuePair[] GetNamesValues(T record, ParameterInfo pInfo) + { + Contracts.CheckValue(pInfo, nameof(pInfo)); + Contracts.CheckParam(typeof(T) == pInfo.ParameterType, nameof(pInfo), "Type mismatch with " + nameof(record)); + var result = NameUtil.GetNames(record, pInfo); + var retVal = new KeyValuePair[result.Length]; + for (int i = 0; i < result.Length; ++i) + retVal[i] = new KeyValuePair(result[i].name, result[i].value); + return retVal; + } + + /// + /// A sort of extended version of that accounts + /// for the presence of the , and types. /> + /// + /// Can we assign to this type? + /// From that type? + /// + public static bool IsAssignableFromStaticPipeline(this Type to, Type from) + { + Contracts.AssertValue(to); + Contracts.AssertValue(from); + if (to.IsAssignableFrom(from)) + return true; + // The only exception to the above test are the vector types. These are generic types. 
+ if (!to.IsGenericType || !from.IsGenericType) + return false; + var gto = to.GetGenericTypeDefinition(); + var gfrom = from.GetGenericTypeDefinition(); + + // If either of the types is not one of the vector types, we can just stop right here. + if ((gto != typeof(Vector<>) && gto != typeof(VarVector<>) && gto != typeof(NormVector<>)) || + (gfrom != typeof(Vector<>) && gfrom != typeof(VarVector<>) && gfrom != typeof(NormVector<>))) + { + return false; + } + + // First check the value types. If those don't match, no sense going any further. + var ato = to.GetGenericArguments(); + var afrom = from.GetGenericArguments(); + Contracts.Assert(Utils.Size(ato) == 1); + Contracts.Assert(Utils.Size(afrom) == 1); + + if (!ato[0].IsAssignableFrom(afrom[0])) + return false; + + // We have now confirmed at least the compatibility of the item types. Next we must confirm the same of the vector type. + // Variable sized vectors must match in their types, norm vector can be considered assignable to vector. + + // If either is a var vector, the other must be as well. + if (gto == typeof(VarVector<>)) + return gfrom == typeof(VarVector<>); + + // We can assign from NormVector<> to Vector<>, but not the other way around. So we only fail if we are trying to assign Vector<> to NormVector<>. + return gfrom != typeof(Vector<>) || gto != typeof(NormVector<>); + } + + /// + /// Utility for extracting names out of value-tuple tree structures. + /// + /// + private static class NameUtil + { + private struct Info + { + public readonly Type Type; + public readonly object Item; + + public Info(Type type, object item) + { + Type = type; + Item = item; + } + } + + /// + /// A utility for exacting name/type/value triples out of a value-tuple based tree structure. 
+ /// + /// For example: If were then the value-tuple + /// (a: 1, b: (c: 2, d: 3), e: 4) would result in the return array where the name/value + /// pairs were [("a", 1), ("b.c", 2), ("b.d", 3), "e", 4], in some order, and the type + /// is typeof(int). + /// + /// Note that the type returned in the triple is the type as declared in the tuple, which will + /// be a derived type of , and in turn the type of the value will be + /// of a type derived from that type. + /// + /// This method will throw if anything other than value-tuples or + /// instances are detected during its execution. + /// + /// The type to extract on. + /// The instance to extract values out of. + /// A type parameter associated with this, usually extracted out of some + /// delegate over this value tuple type. Note that names in value-tuples are an illusion perpetrated + /// by the C# compiler, and are not accessible though by reflection, which + /// is why it is necessary to engage in trickery like passing in a delegate over those types, which + /// does retain the information on the names. + /// The list of name/type/value triples extracted out of the tree like-structure + public static (string name, Type type, TLeaf value)[] GetNames(T record, ParameterInfo pInfo) + { + Contracts.AssertValue(pInfo); + Contracts.Assert(typeof(T) == pInfo.ParameterType); + // Record can only be null if it isn't the value tuple type. + + if (typeof(TLeaf).IsAssignableFrom(typeof(T))) + return new[] { ("Data", typeof(T), (TLeaf)(object)record) }; + + // The structure of names for value tuples is somewhat unusual. All names in a nested structure of value + // tuples is arranged in a roughly depth-first structure, unless we consider tuple cardinality greater + // than seven (which is physically stored in a tuple of cardinality eight, with the so-called `Rest` + // field iteratively holding "more" values. 
So what appears to be a ten-tuple is really an eight-tuple, + // with the first seven items holding the first seven items of the original tuple, and another value + // tuple in `Rest` holding the remaining three items. + + // Anyway: the names are given in depth-first fashion with all items in a tuple being assigned + // contiguously to the items (so for any n-tuple, there is an contiguous n-length segment in the names + // array corresponding to the names). This also applies to the "virtual" >7 tuples, which are for this + // purpose considered "one" tuple, which has some interesting implications on valid traversals of the + // structure. + + var tupleNames = pInfo.GetCustomAttribute()?.TransformNames; + var accumulated = new List<(string, Type, TLeaf)>(); + RecurseNames(record, tupleNames, 0, null, accumulated); + return accumulated.ToArray(); + } + + /// + /// Helper method for , that given a + /// will either append triples to (if the item is of type + /// ), or recurse on this function (if the item is a ), + /// or otherwise throw an error. + /// + /// The type we are recursing on, should be a of some sort + /// The we are extracting on. Note that this is + /// just for the sake of ease of using + /// . + /// The names list extracted from the attribute, or null + /// if no such attribute could be found. + /// The offset into where 's names begin. + /// null for the root level structure, or the appendation of . suffixed names + /// of the path of value-tuples down to this item. 
+ /// The list into which the names are being added + /// The total number of items added to + private static int RecurseNames(object record, IList names, int namesOffset, string namePrefix, List<(string, Type, TLeaf)> accum) + { + if (!ValueTupleUtils.IsValueTuple(typeof(T))) + { + throw Contracts.Except($"Expected to find structure composed of {typeof(ValueTuple)} and {typeof(TLeaf)} " + + $" but during traversal of the structure an item of {typeof(T)} was found instead"); + } + Contracts.AssertValue(record); + Contracts.Assert(record is T); + Contracts.AssertValueOrNull(names); + Contracts.Assert(names == null || namesOffset <= names.Count); + Contracts.AssertValueOrNull(namePrefix); + Contracts.AssertValue(accum); + + var tupleItems = new List(); + + ValueTupleUtils.ApplyActionToTuple((T)record, (index, type, item) + => tupleItems.Add(new Info(type, item))); + int total = tupleItems.Count; + + for (int i = 0; i < tupleItems.Count; ++i) + { + string name = names?[namesOffset + i] ?? $"Item{i + 1}"; + if (!string.IsNullOrEmpty(namePrefix)) + name = namePrefix + name; + + if (typeof(TLeaf).IsAssignableFrom(tupleItems[i].Type)) + accum.Add((name, tupleItems[i].Type, (TLeaf)tupleItems[i].Item)); + else + { + total += Utils.MarshalInvoke(RecurseNames, tupleItems[i].Type, + tupleItems[i].Item, names, namesOffset + total, name + ".", accum); + } + } + + return total; + } + } + + private static class ValueTupleUtils + { + public static bool IsValueTuple(Type t) + { + Type genT = t.IsGenericType ? 
t.GetGenericTypeDefinition() : t; + return genT == typeof(ValueTuple<>) || genT == typeof(ValueTuple<,>) || genT == typeof(ValueTuple<,,>) + || genT == typeof(ValueTuple<,,,>) || genT == typeof(ValueTuple<,,,,>) || genT == typeof(ValueTuple<,,,,,>) + || genT == typeof(ValueTuple<,,,,,,>) || genT == typeof(ValueTuple<,,,,,,,>); + } + + public delegate void TupleItemAction(int index, Type itemType, object item); + + public static void ApplyActionToTuple(T tuple, TupleItemAction action) + { + Contracts.CheckValue(action, nameof(action)); + ApplyActionToTuple(tuple, 0, action); + } + + internal static void ApplyActionToTuple(object tuple, int root, TupleItemAction action) + { + Contracts.AssertValue(action); + Contracts.Assert(root >= 0); + + var tType = typeof(T); + if (tType.IsGenericType) + tType = tType.GetGenericTypeDefinition(); + + if (typeof(ValueTuple<>) == tType) + MarshalInvoke>(Process, tuple, root, action); + else if (typeof(ValueTuple<,>) == tType) + MarshalInvoke>(Process, tuple, root, action); + else if (typeof(ValueTuple<,,>) == tType) + MarshalInvoke>(Process, tuple, root, action); + else if (typeof(ValueTuple<,,,>) == tType) + MarshalInvoke>(Process, tuple, root, action); + else if (typeof(ValueTuple<,,,,>) == tType) + MarshalInvoke>(Process, tuple, root, action); + else if (typeof(ValueTuple<,,,,,>) == tType) + MarshalInvoke>(Process, tuple, root, action); + else if (typeof(ValueTuple<,,,,,,>) == tType) + MarshalInvoke>(Process, tuple, root, action); + else if (typeof(ValueTuple<,,,,,,,>) == tType) + MarshalInvoke>>(Process, tuple, root, action); + else + { + // This will fall through here if this was either not a generic type or is a value tuple type. 
+ throw Contracts.ExceptParam(nameof(tuple), $"Item should have been a {nameof(ValueTuple)} but was instead {tType}"); + } + } + + private delegate void Processor(T val, int root, TupleItemAction action); + + private static void Process(ValueTuple val, int root, TupleItemAction action) + { + action(root++, typeof(T1), val.Item1); + } + + private static void Process(ValueTuple val, int root, TupleItemAction action) + { + action(root++, typeof(T1), val.Item1); + action(root++, typeof(T2), val.Item2); + } + + private static void Process(ValueTuple val, int root, TupleItemAction action) + { + action(root++, typeof(T1), val.Item1); + action(root++, typeof(T2), val.Item2); + action(root++, typeof(T3), val.Item3); + } + + private static void Process(ValueTuple val, int root, TupleItemAction action) + { + action(root++, typeof(T1), val.Item1); + action(root++, typeof(T2), val.Item2); + action(root++, typeof(T3), val.Item3); + action(root++, typeof(T4), val.Item4); + } + + private static void Process(ValueTuple val, int root, TupleItemAction action) + { + action(root++, typeof(T1), val.Item1); + action(root++, typeof(T2), val.Item2); + action(root++, typeof(T3), val.Item3); + action(root++, typeof(T4), val.Item4); + action(root++, typeof(T5), val.Item5); + } + + private static void Process(ValueTuple val, int root, TupleItemAction action) + { + action(root++, typeof(T1), val.Item1); + action(root++, typeof(T2), val.Item2); + action(root++, typeof(T3), val.Item3); + action(root++, typeof(T4), val.Item4); + action(root++, typeof(T5), val.Item5); + action(root++, typeof(T6), val.Item6); + } + + private static void Process(ValueTuple val, int root, TupleItemAction action) + { + action(root++, typeof(T1), val.Item1); + action(root++, typeof(T2), val.Item2); + action(root++, typeof(T3), val.Item3); + action(root++, typeof(T4), val.Item4); + action(root++, typeof(T5), val.Item5); + action(root++, typeof(T6), val.Item6); + action(root++, typeof(T7), val.Item7); + } + + private 
static void Process(ValueTuple val, int root, TupleItemAction action) + where TRest : struct + { + action(root++, typeof(T1), val.Item1); + action(root++, typeof(T2), val.Item2); + action(root++, typeof(T3), val.Item3); + action(root++, typeof(T4), val.Item4); + action(root++, typeof(T5), val.Item5); + action(root++, typeof(T6), val.Item6); + action(root++, typeof(T7), val.Item7); + ApplyActionToTuple(val.Rest, root++, action); + } + + private static void MarshalInvoke(Processor del, object arg, int root, TupleItemAction action) + { + Contracts.AssertValue(del); + Contracts.Assert(del.Method.IsGenericMethod); + var argType = arg.GetType(); + Contracts.Assert(argType.IsGenericType); + var argGenTypes = argType.GetGenericArguments(); + // The argument generic types should be compatible with the delegate's generic types. + Contracts.Assert(del.Method.GetGenericArguments().Length == argGenTypes.Length); + // Reconstruct the delegate generic types so it adheres to the args generic types. + var newDel = del.Method.GetGenericMethodDefinition().MakeGenericMethod(argGenTypes); + + var result = newDel.Invoke(null, new object[] { arg, root, action }); + } + } + } +} diff --git a/src/Microsoft.ML.Data/StaticPipe/StaticPipeUtils.cs b/src/Microsoft.ML.Data/StaticPipe/StaticPipeUtils.cs new file mode 100644 index 0000000000..6da97abe8a --- /dev/null +++ b/src/Microsoft.ML.Data/StaticPipe/StaticPipeUtils.cs @@ -0,0 +1,370 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; + +namespace Microsoft.ML.Data.StaticPipe.Runtime +{ + /// + /// Utility methods for components that want to expose themselves in the idioms of the statically-typed pipelines. 
+ /// These utilities are meant to be called by and useful to component authors, not users of those components. + /// + public static class StaticPipeUtils + { + /// + /// This is a utility method intended to be used by authors of components to provide a strongly typed . + /// This analysis tool provides a standard way for readers to exploit statically typed pipelines with the + /// standard tuple-shape objects without having to write such code themselves. + /// + /// Estimators will be instantiated with this environment + /// /// Some minor debugging information will be passed along to this channel + /// The input that will be used when invoking , which is used + /// either to produce the input columns. + /// All columns that are yielded by should produce this + /// single reconciler. The analysis code in this method will ensure that this is the first object to be + /// reconciled, before all others. + /// The user provided delegate. + /// The type parameter for the input type to the data reader estimator. + /// The input type of the input delegate. 
This might be some object out of + /// which one can fetch or else retrieve + /// + /// + public static DataReaderEstimator> + ReaderEstimatorAnalyzerHelper( + IHostEnvironment env, + IChannel ch, + TDelegateInput input, + ReaderReconciler baseReconciler, + Func mapper) + { + var readerEstimator = GeneralFunctionAnalyzer(env, ch, input, baseReconciler, mapper, out var est, col => null); + var schema = StaticSchemaShape.Make(mapper.Method.ReturnParameter); + return new DataReaderEstimator>(env, readerEstimator, schema); + } + + internal static IDataReaderEstimator> + GeneralFunctionAnalyzer( + IHostEnvironment env, + IChannel ch, + TDelegateInput input, + ReaderReconciler baseReconciler, + Func mapper, + out IEstimator estimator, + Func inputNameFunction) + { + Contracts.CheckValue(mapper, nameof(mapper)); + + var method = mapper.Method; + var output = mapper(input); + + KeyValuePair[] outPairs = StaticPipeInternalUtils.GetNamesValues(output, method.ReturnParameter); + + // Map where the key depends on the set of things in the value. The value contains the yet unresolved dependencies. + var keyDependsOn = new Dictionary>(); + // Map where the set of things in the value depend on the key. + var dependsOnKey = new Dictionary>(); + // The set of columns detected with zero dependencies. + var zeroDependencies = new List(); + + // First we build up the two structures above, using a queue and visiting from the outputs up. + var toVisit = new Queue(outPairs.Select(p => p.Value)); + while (toVisit.Count > 0) + { + var col = toVisit.Dequeue(); + ch.CheckParam(col != null, nameof(mapper), "The delegate seems to have null columns returned somewhere in the pipe"); + if (keyDependsOn.ContainsKey(col)) + continue; // Already visited. + + var dependsOn = new HashSet(); + foreach (var dep in col.Dependencies ?? 
Enumerable.Empty()) + { + dependsOn.Add(dep); + if (!dependsOnKey.TryGetValue(dep, out var dependsOnDep)) + { + dependsOnKey[dep] = dependsOnDep = new HashSet(); + toVisit.Enqueue(dep); + } + dependsOnDep.Add(col); + } + keyDependsOn[col] = dependsOn; + if (dependsOn.Count == 0) + zeroDependencies.Add(col); + } + + // Get the base input columns. + var baseInputs = keyDependsOn.Select(p => p.Key).Where(col => col.ReconcilerObj == baseReconciler).ToArray(); + + // The columns that utilize the base reconciler should have no dependencies. This could only happen if + // the caller of this function has introduced a situation whereby they are claiming they can reconcile + // to a data-reader object but still have input data dependencies, which does not make sense and + // indicates that there is a bug in that component code. Unfortunately we can only detect that condition, + // not determine exactly how it arose, but we can still do so to indicate to the user that there is a + // problem somewhere in the stack. + ch.CheckParam(baseInputs.All(col => keyDependsOn[col].Count == 0), + nameof(input), "Bug detected where column producing object was yielding columns with dependencies."); + + // This holds the mappings of columns to names and back. Note that while the same column could be used on + // the *output*, e.g., you could hypothetically have `(a: r.Foo, b: r.Foo)`, we treat that as the last thing + // that is done. + var nameMap = new BidirectionalDictionary(); + + // Check to see if we have any set of initial names. This is important in the case where we are mapping + // in an input data view. 
+ foreach (var col in baseInputs) + { + string inputName = inputNameFunction(col); + if (inputName != null) + { + ch.Assert(!nameMap.ContainsKey(col)); + ch.Assert(!nameMap.ContainsKey(inputName)); + nameMap[col] = inputName; + + ch.Trace($"Using input with name {inputName}"); + } + } + + estimator = null; + var toCopy = new List<(string src, string dst)>(); + + int tempNum = 0; + // For all outputs, get potential name collisions with used inputs. Resolve by assigning the input a temporary name. + foreach (var p in outPairs) + { + // If the name for the output is already used by one of the inputs, and this output column does not + // happen to have the same name, then we need to rename that input to keep it available. + if (nameMap.TryGetValue(p.Key, out var inputCol) && p.Value != inputCol) + { + ch.Assert(baseInputs.Contains(inputCol)); + string tempName = $"#Temp_{tempNum++}"; + ch.Trace($"Input/output name collision: Renaming '{p.Key}' to '{tempName}'"); + toCopy.Add((p.Key, tempName)); + nameMap[tempName] = nameMap[p.Key]; + ch.Assert(!nameMap.ContainsKey(p.Key)); + } + // If we already have a name for this output column, maybe it is used elsewhere. (This can happen when + // the only thing done with an input is we rename it, or output it twice, or something like this.) In + // this case it is most appropriate to delay renaming till after all other processing has been done in + // that case. But otherwise we may as well just take the name. + if (!nameMap.ContainsKey(p.Value)) + nameMap[p.Key] = p.Value; + } + + // If any renamings were necessary, create the CopyColumns estimator. + if (toCopy.Count > 0) + estimator = new CopyColumnsEstimator(env, toCopy.ToArray()); + + // First clear the inputs from zero-dependencies yet to be resolved. + foreach (var col in baseInputs) + { + ch.Assert(zeroDependencies.Contains(col)); + ch.Assert(col.ReconcilerObj == baseReconciler); + + zeroDependencies.Remove(col); // Make more efficient... 
+ if (!dependsOnKey.TryGetValue(col, out var depends)) + continue; + // If any of these base inputs do not have names because, for example, they do not directly appear + // in the outputs and otherwise do not have names, assign them a name. + if (!nameMap.ContainsKey(col)) + nameMap[col] = $"Temp_{tempNum++}"; + + foreach (var depender in depends) + { + var dependencies = keyDependsOn[depender]; + ch.Assert(dependencies.Contains(col)); + dependencies.Remove(col); + if (dependencies.Count == 0) + zeroDependencies.Add(depender); + } + dependsOnKey.Remove(col); + } + + // Call the reconciler to get the base reader estimator. + var readerEstimator = baseReconciler.Reconcile(env, baseInputs, nameMap.AsOther(baseInputs)); + ch.AssertValueOrNull(readerEstimator); + + // Next we iteratively find those columns with zero dependencies, "create" them, and if anything depends on + // these add them to the collection of zero dependencies, etc. etc. + while (zeroDependencies.Count > 0) + { + // All columns with the same reconciler can be transformed together. + + // Note that the following policy of just taking the first group is not optimal. So for example, we + // could have three columns, (a, b, c). If we had the output (a.X(), b.X() c.Y().X()), then maybe we'd + // reconcile a.X() and b.X() together, then reconcile c.Y(), then reconcile c.Y().X() alone. Whereas, we + // could have reconciled c.Y() first, then reconciled a.X(), b.X(), and c.Y().X() together. + var group = zeroDependencies.GroupBy(p => p.ReconcilerObj).First(); + // Beyond that first group that *might* be a data reader reconciler, all subsequent operations will + // be on where the data is already loaded and so accept data as an input, that is, they should produce + // an estimator. If this is not the case something seriously wonky is going on, most probably that the + // user tried to use a column from another source. If this is detected we can produce a sensible error + // message to tell them not to do this. 
+ if (!(group.Key is EstimatorReconciler rec)) + { + throw ch.Except("Columns from multiple sources were detected. " + + "Did the caller use a " + nameof(PipelineColumn) + " from another delegate?"); + } + PipelineColumn[] cols = group.ToArray(); + // All dependencies should, by this time, have names. + ch.Assert(cols.SelectMany(c => c.Dependencies).All(dep => nameMap.ContainsKey(dep))); + foreach (var newCol in cols) + { + if (!nameMap.ContainsKey(newCol)) + nameMap[newCol] = $"#Temp_{tempNum++}"; + + } + + var localInputNames = nameMap.AsOther(cols.SelectMany(c => c.Dependencies ?? Enumerable.Empty())); + var localOutputNames = nameMap.AsOther(cols); + var usedNames = new HashSet(nameMap.Keys1.Except(localOutputNames.Values)); + + var localEstimator = rec.Reconcile(env, cols, localInputNames, localOutputNames, usedNames); + readerEstimator = readerEstimator?.Append(localEstimator); + estimator = estimator?.Append(localEstimator) ?? localEstimator; + + foreach (var newCol in cols) + { + zeroDependencies.Remove(newCol); // Make more efficient!! + + // Finally, we find all columns that depend on this one. If this happened to be the last pending + // dependency, then we add it to the list. + if (dependsOnKey.TryGetValue(newCol, out var depends)) + { + foreach (var depender in depends) + { + var dependencies = keyDependsOn[depender]; + Contracts.Assert(dependencies.Contains(newCol)); + dependencies.Remove(newCol); + if (dependencies.Count == 0) + zeroDependencies.Add(depender); + } + dependsOnKey.Remove(newCol); + } + } + } + + if (keyDependsOn.Any(p => p.Value.Count > 0)) + { + // This might happen if the user does something incredibly strange, like, say, take some prior + // lambda, assign a column to a local variable, then re-use it downstream in a different lambdas. + // The user would have to go to some extraorindary effort to do that, but nonetheless we want to + // fail with a semi-sensible error message. 
+ throw ch.Except("There were some leftover columns with unresolved dependencies. " + + "Did the caller use a " + nameof(PipelineColumn) + " from another delegate?"); + } + + // Now do the final renaming, if any is necessary. + toCopy.Clear(); + foreach (var p in outPairs) + { + // TODO: Right now we just write stuff out. Once the copy-columns estimator is in place + // we ought to do this for real. + Contracts.Assert(nameMap.ContainsKey(p.Value)); + string currentName = nameMap[p.Value]; + if (currentName != p.Key) + { + ch.Trace($"Will copy '{currentName}' to '{p.Key}'"); + toCopy.Add((currentName, p.Key)); + } + } + + // If any final renamings were necessary, insert the appropriate CopyColumns transform. + if (toCopy.Count > 0) + { + var copyEstimator = new CopyColumnsEstimator(env, toCopy.ToArray()); + if (estimator == null) + estimator = copyEstimator; + else + estimator = estimator.Append(copyEstimator); + } + + ch.Trace($"Exiting {nameof(ReaderEstimatorAnalyzerHelper)}"); + + return readerEstimator; + } + + private sealed class BidirectionalDictionary + { + private readonly Dictionary _d12; + private readonly Dictionary _d21; + + public BidirectionalDictionary() + { + _d12 = new Dictionary(); + _d21 = new Dictionary(); + } + + public bool ContainsKey(T1 k) => _d12.ContainsKey(k); + public bool ContainsKey(T2 k) => _d21.ContainsKey(k); + + public IEnumerable Keys1 => _d12.Keys; + public IEnumerable Keys2 => _d21.Keys; + + public bool TryGetValue(T1 k, out T2 v) => _d12.TryGetValue(k, out v); + public bool TryGetValue(T2 k, out T1 v) => _d21.TryGetValue(k, out v); + + public T1 this[T2 key] + { + get => _d21[key]; + set + { + Contracts.CheckValue((object)key, nameof(key)); + Contracts.CheckValue((object)value, nameof(value)); + + bool removeOldKey = _d12.TryGetValue(value, out var oldKey); + if (_d21.TryGetValue(key, out var oldValue)) + _d12.Remove(oldValue); + if (removeOldKey) + _d21.Remove(oldKey); + + _d12[value] = key; + _d21[key] = value; + 
Contracts.Assert(_d12.Count == _d21.Count); + } + } + + public T2 this[T1 key] + { + get => _d12[key]; + set + { + Contracts.CheckValue((object)key, nameof(key)); + Contracts.CheckValue((object)value, nameof(value)); + + bool removeOldKey = _d21.TryGetValue(value, out var oldKey); + if (_d12.TryGetValue(key, out var oldValue)) + _d21.Remove(oldValue); + if (removeOldKey) + _d12.Remove(oldKey); + + _d21[value] = key; + _d12[key] = value; + + Contracts.Assert(_d12.Count == _d21.Count); + } + } + + public IReadOnlyDictionary AsOther(IEnumerable keys) + { + Dictionary d = new Dictionary(); + foreach (var v in keys) + d[v] = _d12[v]; + return d; + } + + public IReadOnlyDictionary AsOther(IEnumerable keys) + { + Dictionary d = new Dictionary(); + foreach (var v in keys) + d[v] = _d21[v]; + return d; + } + } + } +} diff --git a/src/Microsoft.ML.Data/StaticPipe/StaticSchemaShape.cs b/src/Microsoft.ML.Data/StaticPipe/StaticSchemaShape.cs new file mode 100644 index 0000000000..7b257f7b4e --- /dev/null +++ b/src/Microsoft.ML.Data/StaticPipe/StaticSchemaShape.cs @@ -0,0 +1,361 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; +using Microsoft.ML.Core.Data; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; + +namespace Microsoft.ML.Data.StaticPipe.Runtime +{ + /// + /// A schema shape with names corresponding to a type parameter in one of the typed variants + /// of the data pipeline structures. Instances of this class tend to be bundled with the statically + /// typed variants of the dynamic structures (e.g., and so forth), + /// and their primary purpose is to ensure that the schemas of the dynamic structures and the + /// statically declared structures are compatible. 
+ /// + internal sealed class StaticSchemaShape + { + /// + /// The enumeration of name/type pairs. Do not modify. + /// + public readonly KeyValuePair[] Pairs; + + private StaticSchemaShape(KeyValuePair[] pairs) + { + Contracts.AssertValue(pairs); + Pairs = pairs; + } + + /// + /// Creates a new instance out of a parameter info, presumably fetched from a user specified delegate. + /// + /// The static tuple-shape type + /// The parameter info on the method, whose type should be + /// + /// A new instance with names and members types enumerated + public static StaticSchemaShape Make(ParameterInfo info) + { + Contracts.AssertValue(info); + var pairs = StaticPipeInternalUtils.GetNamesTypes(info); + return new StaticSchemaShape(pairs); + } + + /// + /// Checks whether this object is consistent with an actual schema from a dynamic object, + /// throwing exceptions if not. + /// + /// The context on which to throw exceptions + /// The schema to check + public void Check(IExceptionContext ectx, ISchema schema) + { + Contracts.AssertValue(ectx); + ectx.AssertValue(schema); + + foreach (var pair in Pairs) + { + if (!schema.TryGetColumnIndex(pair.Key, out int colIdx)) + throw ectx.ExceptParam(nameof(schema), $"Column named '{pair.Key}' was not found"); + var col = RowColumnUtils.GetColumn(schema, colIdx); + var type = GetTypeOrNull(col); + if ((type != null && !pair.Value.IsAssignableFromStaticPipeline(type)) || (type == null && IsStandard(ectx, pair.Value))) + { + // When not null, we can use IsAssignableFrom to indicate we could assign to this, so as to allow + // for example Key to be considered to be compatible with Key. + + // In the null case, while we cannot directly verify an unrecognized type, we can at least verify + // that the statically declared type should not have corresponded to a recognized type. 
+ if (!pair.Value.IsAssignableFromStaticPipeline(type)) + { + throw ectx.ExceptParam(nameof(schema), + $"Column '{pair.Key}' of type '{col.Type}' cannot be expressed statically as type '{pair.Value}'."); + } + } + } + } + + /// + /// Checks whether this object is consistent with an actual schema shape from a dynamic object, + /// throwing exceptions if not. + /// + /// The context on which to throw exceptions + /// The schema shape to check + public void Check(IExceptionContext ectx, SchemaShape shape) + { + Contracts.AssertValue(ectx); + ectx.AssertValue(shape); + + foreach (var pair in Pairs) + { + var col = shape.FindColumn(pair.Key); + if (col == null) + throw ectx.ExceptParam(nameof(shape), $"Column named '{pair.Key}' was not found"); + var type = GetTypeOrNull(col); + if ((type != null && !pair.Value.IsAssignableFromStaticPipeline(type)) || (type == null && IsStandard(ectx, pair.Value))) + { + // When not null, we can use IsAssignableFrom to indicate we could assign to this, so as to allow + // for example Key to be considered to be compatible with Key. + + // In the null case, while we cannot directly verify an unrecognized type, we can at least verify + // that the statically declared type should not have corresponded to a recognized type. + if (!pair.Value.IsAssignableFromStaticPipeline(type)) + { + // This is generally an error, unless it's the situation where the asserted type is Key<,> but we could + // only resolve it so far as Key<>, since for the moment the SchemaShape cannot determine the type of key + // value metadata. In which case, we can check if the declared type is a subtype of the key that was determined + // from the analysis. 
+ if (pair.Value.IsGenericType && pair.Value.GetGenericTypeDefinition() == typeof(Key<,>) && + type.IsAssignableFromStaticPipeline(pair.Value)) + { + continue; + } + throw ectx.ExceptParam(nameof(shape), + $"Column '{pair.Key}' of type '{col.GetTypeString()}' cannot be expressed statically as type '{pair.Value}'."); + } + } + } + } + + private static Type GetTypeOrNull(SchemaShape.Column col) + { + Contracts.AssertValue(col); + + Type vecType = null; + switch (col.Kind) + { + case SchemaShape.Column.VectorKind.Scalar: + break; // Keep it null. + case SchemaShape.Column.VectorKind.Vector: + // Assume that if the normalized metadata is indicated by the schema shape, it is bool and true. + vecType = col.MetadataKinds.Contains(MetadataUtils.Kinds.IsNormalized) ? typeof(NormVector<>) : typeof(Vector<>); + break; + case SchemaShape.Column.VectorKind.VariableVector: + vecType = typeof(VarVector<>); + break; + default: + // Not recognized. Not necessarily an error of the user, may just indicate this code ought to be updated. + Contracts.Assert(false); + return null; + } + + if (col.IsKey) + { + Type physType = StaticKind(col.ItemType.RawKind); + Contracts.Assert(physType == typeof(byte) || physType == typeof(ushort) + || physType == typeof(uint) || physType == typeof(ulong)); + // As of the time of this writing we cannot distinguish between multiple types of key value metadata, + // so, we don't try. This is tracked in this issue: https://github.com/dotnet/machinelearning/issues/755. + // Because Key<,> descends from Key<> the check will still work. Also the idiom here has no way of + // representing variable size keys. + var keyType = typeof(Key<>).MakeGenericType(physType); + return vecType?.MakeGenericType(keyType) ?? 
keyType; + } + + if (col.ItemType is PrimitiveType pt) + { + Type physType = StaticKind(pt.RawKind); + // Though I am unaware of any existing instances, it is theoretically possible for a + // primitive type to exist, have the same data kind as one of the existing types, and yet + // not be one of the built in types. (E.g., an outside analogy to the key types.) For this + // reason, we must be certain that when we return here we are covering one fo the builtin types. + if (physType != null && ( + pt == NumberType.I1 || pt == NumberType.I2 || pt == NumberType.I4 || pt == NumberType.I4 || + pt == NumberType.U1 || pt == NumberType.U2 || pt == NumberType.U4 || pt == NumberType.U4 || + pt == NumberType.R4 || pt == NumberType.R8 || pt == NumberType.UG || pt == BoolType.Instance || + pt == DateTimeType.Instance || pt == DateTimeZoneType.Instance || pt == TimeSpanType.Instance || + pt == TextType.Instance)) + { + return (vecType ?? typeof(Scalar<>)).MakeGenericType(physType); + } + } + + return null; + } + + /// + /// Returns true if the input type is something recognizable as being oen of the standard + /// builtin types. This method will also throw if something is detected as being definitely + /// wrong (e.g., the input type does not descend from at all, + /// or a is declared with a type parameter or + /// something. + /// + private static bool IsStandard(IExceptionContext ectx, Type t) + { + Contracts.AssertValue(ectx); + ectx.AssertValue(t); + if (!typeof(PipelineColumn).IsAssignableFrom(t)) + { + throw ectx.ExceptParam(nameof(t), $"Type {t} was not even of {nameof(PipelineColumn)}"); + } + var gt = t.IsGenericType ? 
t.GetGenericTypeDefinition() : t; + if (gt != typeof(Scalar<>) && gt != typeof(Key<>) && gt != typeof(Key<,>) && gt != typeof(VarKey<>) && + gt != typeof(Vector<>) && gt != typeof(VarVector<>) && gt != typeof(NormVector<>)) + { + throw ectx.ExceptParam(nameof(t), + $"Type {t} was not one of the standard subclasses of {nameof(PipelineColumn)}"); + } + ectx.Assert(t.IsGenericType); + var ga = t.GetGenericArguments(); + ectx.AssertNonEmpty(ga); + + if (gt == typeof(Key<>) || gt == typeof(Key<,>) || gt == typeof(VarKey<>)) + { + ectx.Assert((gt == typeof(Key<,>) && ga.Length == 2) || ga.Length == 1); + var kt = ga[0]; + if (kt != typeof(byte) && kt != typeof(ushort) && kt != typeof(uint) && kt != typeof(ulong)) + throw ectx.ExceptParam(nameof(t), $"Type parameter {kt.Name} is not a valid type for key"); + return gt != typeof(Key<,>) || IsStandardCore(ga[1]); + } + + ectx.Assert(ga.Length == 1); + return IsStandardCore(ga[0]); + } + + private static bool IsStandardCore(Type t) + { + Contracts.AssertValue(t); + return t == typeof(float) || t == typeof(double) || t == typeof(string) || t == typeof(bool) || + t == typeof(sbyte) || t == typeof(short) || t == typeof(int) || t == typeof(long) || + t == typeof(byte) || t == typeof(ushort) || t == typeof(uint) || t == typeof(ulong) || + t == typeof(TimeSpan) || t == typeof(DateTime) || t == typeof(DateTimeOffset); + } + + /// + /// Returns a .NET type corresponding to the static pipelines that would tend to represent this column. + /// Generally this will return null if it simply does not recognize the type but might throw if + /// there is something seriously wrong with it. 
+ /// + /// The column + /// The .NET type for the static pipelines that should be used to reflect this type, given + /// both the characteristics of the as well as one or two crucial pieces of metadata + private static Type GetTypeOrNull(IColumn col) + { + Contracts.AssertValue(col); + var t = col.Type; + + Type vecType = null; + if (t is VectorType vt) + { + vecType = vt.VectorSize > 0 ? typeof(Vector<>) : typeof(VarVector<>); + // Check normalized subtype of vectors. + if (vt.VectorSize > 0) + { + // Check to see if the column is normalized. + // Once we shift to metadata being a row globally we can also make this a bit more efficient: + var meta = col.Metadata; + if (meta.Schema.TryGetColumnIndex(MetadataUtils.Kinds.IsNormalized, out int normcol)) + { + var normtype = meta.Schema.GetColumnType(normcol); + if (normtype == BoolType.Instance) + { + DvBool val = default; + meta.GetGetter(normcol)(ref val); + if (val.IsTrue) + vecType = typeof(NormVector<>); + } + } + } + t = t.ItemType; + // Fall through to the non-vector case to handle subtypes. + } + Contracts.Assert(!t.IsVector); + + if (t is KeyType kt) + { + Type physType = StaticKind(kt.RawKind); + Contracts.Assert(physType == typeof(byte) || physType == typeof(ushort) + || physType == typeof(uint) || physType == typeof(ulong)); + var keyType = kt.Count > 0 ? typeof(Key<>) : typeof(VarKey<>); + keyType = keyType.MakeGenericType(physType); + + if (kt.Count > 0) + { + // Check to see if we have key value metadata of the appropriate type, size, and whatnot. 
+ var meta = col.Metadata; + if (meta.Schema.TryGetColumnIndex(MetadataUtils.Kinds.KeyValues, out int kvcol)) + { + var kvType = meta.Schema.GetColumnType(kvcol); + if (kvType.VectorSize == kt.Count) + { + Contracts.Assert(kt.Count > 0); + var subtype = GetTypeOrNull(RowColumnUtils.GetColumn(meta, kvcol)); + if (subtype != null && subtype.IsGenericType) + { + var sgtype = subtype.GetGenericTypeDefinition(); + if (sgtype == typeof(NormVector<>) || sgtype == typeof(Vector<>)) + { + var args = subtype.GetGenericArguments(); + Contracts.Assert(args.Length == 1); + keyType = typeof(Key<,>).MakeGenericType(physType, args[0]); + } + } + } + } + } + return vecType?.MakeGenericType(keyType) ?? keyType; + } + + if (t is PrimitiveType pt) + { + Type physType = StaticKind(pt.RawKind); + // Though I am unaware of any existing instances, it is theoretically possible for a + // primitive type to exist, have the same data kind as one of the existing types, and yet + // not be one of the built in types. (E.g., an outside analogy to the key types.) For this + // reason, we must be certain that when we return here we are covering one fo the builtin types. + if (physType != null && ( + pt == NumberType.I1 || pt == NumberType.I2 || pt == NumberType.I4 || pt == NumberType.I8 || + pt == NumberType.U1 || pt == NumberType.U2 || pt == NumberType.U4 || pt == NumberType.U8 || + pt == NumberType.R4 || pt == NumberType.R8 || pt == NumberType.UG || pt == BoolType.Instance || + pt == DateTimeType.Instance || pt == DateTimeZoneType.Instance || pt == TimeSpanType.Instance || + pt == TextType.Instance)) + { + return (vecType ?? typeof(Scalar<>)).MakeGenericType(physType); + } + } + + return null; + } + + /// + /// Note that this can return a different type than the actual physical representation type, e.g., for + /// the return type is , even though we do not use that + /// type for communicating text. 
+ /// + /// The basic type used to represent an item type in the static pipeline + private static Type StaticKind(DataKind kind) + { + switch (kind) + { + // The default kind is reserved for unknown types. + case default(DataKind): return null; + case DataKind.I1: return typeof(sbyte); + case DataKind.I2: return typeof(short); + case DataKind.I4: return typeof(int); + case DataKind.I8: return typeof(long); + + case DataKind.U1: return typeof(byte); + case DataKind.U2: return typeof(ushort); + case DataKind.U4: return typeof(uint); + case DataKind.U8: return typeof(ulong); + case DataKind.U16: return typeof(UInt128); + + case DataKind.R4: return typeof(float); + case DataKind.R8: return typeof(double); + case DataKind.BL: return typeof(bool); + + case DataKind.Text: return typeof(string); + case DataKind.TimeSpan: return typeof(TimeSpan); + case DataKind.DateTime: return typeof(DateTime); + case DataKind.DateTimeZone: return typeof(DateTimeOffset); + + default: + throw Contracts.ExceptParam(nameof(kind), $"Unrecognized type '{kind}'"); + } + } + } +} diff --git a/src/Microsoft.ML.Data/StaticPipe/Transformer.cs b/src/Microsoft.ML.Data/StaticPipe/Transformer.cs new file mode 100644 index 0000000000..4072397928 --- /dev/null +++ b/src/Microsoft.ML.Data/StaticPipe/Transformer.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using Microsoft.ML.Core.Data; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Data.StaticPipe.Runtime; + +namespace Microsoft.ML.Data.StaticPipe +{ + public sealed class Transformer : SchemaBearing + where TTransformer : class, ITransformer + { + public TTransformer AsDynamic { get; } + private readonly StaticSchemaShape _inShape; + + internal Transformer(IHostEnvironment env, TTransformer transformer, StaticSchemaShape inShape, StaticSchemaShape outShape) + : base(env, outShape) + { + Env.AssertValue(transformer); + Env.AssertValue(inShape); + AsDynamic = transformer; + _inShape = inShape; + // The ability to check at runtime is limited. We could check during transformation time on the input data view. + } + + public Transformer> + Append(Transformer transformer) + where TNewTransformer : class, ITransformer + { + Env.Assert(nameof(Append) == nameof(LearningPipelineExtensions.Append)); + + var trans = AsDynamic.Append(transformer.AsDynamic); + return new Transformer>(Env, trans, _inShape, transformer.Shape); + } + + public DataView Transform(DataView input) + { + Env.Assert(nameof(Transform) == nameof(ITransformer.Transform)); + Env.CheckValue(input, nameof(input)); + + var view = AsDynamic.Transform(input.AsDynamic); + return new DataView(Env, view, Shape); + } + } +} diff --git a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs index 21076b2d86..35afe6a9f3 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. 
using Microsoft.ML.Core.Data; +using Microsoft.ML.Data.StaticPipe.Runtime; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -155,7 +156,7 @@ private sealed class Mapper : MapperBase private ImageGrayscaleTransform _parent; public Mapper(ImageGrayscaleTransform parent, ISchema inputSchema) - :base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) + : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) { _parent = parent; } @@ -232,5 +233,51 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) return new SchemaShape(result.Values); } + + private interface IColInput + { + PipelineColumn Input { get; } + } + + internal sealed class OutPipelineColumn : Scalar, IColInput + { + public PipelineColumn Input { get; } + + public OutPipelineColumn(Scalar input) + : base(Reconciler.Inst, input) + { + Contracts.AssertValue(input); + Contracts.Assert(typeof(T) == typeof(Bitmap) || typeof(T) == typeof(UnknownSizeBitmap)); + Input = input; + } + } + + /// + /// Reconciler to an for the . 
+ /// + /// Because we want to use the same reconciler for + /// + /// + private sealed class Reconciler : EstimatorReconciler + { + public static Reconciler Inst = new Reconciler(); + + private Reconciler() { } + + public override IEstimator Reconcile(IHostEnvironment env, + PipelineColumn[] toOutput, + IReadOnlyDictionary inputNames, + IReadOnlyDictionary outputNames, + IReadOnlyCollection usedNames) + { + var cols = new (string input, string output)[toOutput.Length]; + for (int i = 0; i < toOutput.Length; ++i) + { + var outCol = (IColInput)toOutput[i]; + cols[i] = (inputNames[outCol.Input], outputNames[toOutput[i]]); + } + return new ImageGrayscaleEstimator(env, cols); + } + } } } diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs index 20e1476feb..d02c261f37 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML.Core.Data; +using Microsoft.ML.Data.StaticPipe.Runtime; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -239,5 +240,62 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) return new SchemaShape(result.Values); } + + internal sealed class OutPipelineColumn : Scalar + { + private readonly Scalar _input; + + public OutPipelineColumn(Scalar path, string relativeTo) + : base(new Reconciler(relativeTo), path) + { + Contracts.AssertValue(path); + _input = path; + } + + /// + /// Reconciler to an for the . + /// + /// + /// We must create a new reconciler per call, because the relative path of + /// is considered a transform-wide option, as it is not specified in . 
However, we still + /// implement so the analyzer can still equate two of these things if they happen to share the same + /// path, so we can be a bit more efficient with respect to our estimator declarations. + /// + /// + private sealed class Reconciler : EstimatorReconciler, IEquatable + { + private readonly string _relTo; + + public Reconciler(string relativeTo) + { + Contracts.AssertValueOrNull(relativeTo); + _relTo = relativeTo; + } + + public bool Equals(Reconciler other) + => other != null && other._relTo == _relTo; + + public override bool Equals(object obj) + => obj is Reconciler other && Equals(other); + + public override int GetHashCode() + => _relTo?.GetHashCode() ?? 0; + + public override IEstimator Reconcile(IHostEnvironment env, + PipelineColumn[] toOutput, + IReadOnlyDictionary inputNames, + IReadOnlyDictionary outputNames, + IReadOnlyCollection usedNames) + { + var cols = new (string input, string output)[toOutput.Length]; + for (int i = 0; i < toOutput.Length; ++i) + { + var outCol = (OutPipelineColumn)toOutput[i]; + cols[i] = (inputNames[outCol._input], outputNames[outCol]); + } + return new ImageLoaderEstimator(env, _relTo, cols); + } + } + } } } diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs index 0bd5bf7879..3d98bd3430 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs @@ -8,6 +8,7 @@ using System.Linq; using System.Text; using Microsoft.ML.Core.Data; +using Microsoft.ML.Data.StaticPipe.Runtime; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -142,10 +143,10 @@ public sealed class ColumnInfo public readonly float Scale; public readonly bool Interleave; - public bool Alpha { get { return (Colors & ColorBits.Alpha) != 0; } } - public bool Red { get { return (Colors & ColorBits.Red) != 0; } } - public 
bool Green { get { return (Colors & ColorBits.Green) != 0; } } - public bool Blue { get { return (Colors & ColorBits.Blue) != 0; } } + public bool Alpha => (Colors & ColorBits.Alpha) != 0; + public bool Red => (Colors & ColorBits.Red) != 0; + public bool Green => (Colors & ColorBits.Green) != 0; + public bool Blue => (Colors & ColorBits.Blue) != 0; internal ColumnInfo(Column item, Arguments args) { @@ -664,5 +665,70 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) return new SchemaShape(result.Values); } + + private interface IColInput + { + Scalar Input { get; } + + ImagePixelExtractorTransform.ColumnInfo MakeColumnInfo(string input, string output); + } + + internal sealed class OutPipelineColumn : Vector, IColInput + { + public Scalar Input { get; } + private static readonly ImagePixelExtractorTransform.Arguments _defaultArgs = new ImagePixelExtractorTransform.Arguments(); + private readonly ImagePixelExtractorTransform.Column _colParam; + + public OutPipelineColumn(Scalar input, ImagePixelExtractorTransform.Column col) + : base(Reconciler.Inst, input) + { + Contracts.AssertValue(input); + Contracts.Assert(typeof(T) == typeof(float) || typeof(T) == typeof(byte)); + Input = input; + _colParam = col; + } + + public ImagePixelExtractorTransform.ColumnInfo MakeColumnInfo(string input, string output) + { + // In principle, the analyzer should only call the the reconciler once for these columns. + Contracts.Assert(_colParam.Source == null); + Contracts.Assert(_colParam.Name == null); + + _colParam.Name = output; + _colParam.Source = input; + return new ImagePixelExtractorTransform.ColumnInfo(_colParam, _defaultArgs); + } + } + + /// + /// Reconciler to an for the . + /// + /// Because we want to use the same reconciler for + /// + /// + private sealed class Reconciler : EstimatorReconciler + { + /// + /// Because there are no global settings that cannot be overridden, we can always just use the same reconciler. 
+ /// + public static Reconciler Inst = new Reconciler(); + + private Reconciler() { } + + public override IEstimator Reconcile(IHostEnvironment env, + PipelineColumn[] toOutput, + IReadOnlyDictionary inputNames, + IReadOnlyDictionary outputNames, + IReadOnlyCollection usedNames) + { + var cols = new ImagePixelExtractorTransform.ColumnInfo[toOutput.Length]; + for (int i = 0; i < toOutput.Length; ++i) + { + var outCol = (IColInput)toOutput[i]; + cols[i] = outCol.MakeColumnInfo(inputNames[outCol.Input], outputNames[toOutput[i]]); + } + return new ImagePixelExtractorEstimator(env, cols); + } + } } } diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs index ac11c7fa8d..30cf830758 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs @@ -8,6 +8,7 @@ using System.Linq; using System.Text; using Microsoft.ML.Core.Data; +using Microsoft.ML.Data.StaticPipe.Runtime; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; @@ -458,5 +459,58 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) return new SchemaShape(result.Values); } + + internal sealed class OutPipelineColumn : Scalar + { + private readonly PipelineColumn _input; + private readonly int _width; + private readonly int _height; + private readonly ImageResizerTransform.ResizingKind _resizing; + private readonly ImageResizerTransform.Anchor _cropAnchor; + + public OutPipelineColumn(PipelineColumn input, int width, int height, + ImageResizerTransform.ResizingKind resizing, ImageResizerTransform.Anchor cropAnchor) + : base(Reconciler.Inst, input) + { + Contracts.AssertValue(input); + _input = input; + _width = width; + _height = height; + _resizing = resizing; + _cropAnchor = cropAnchor; + } + + private ImageResizerTransform.ColumnInfo MakeColumnInfo(string input, string output) + => new 
ImageResizerTransform.ColumnInfo(input, output, _width, _height, _resizing, _cropAnchor); + + /// + /// Reconciler to an for the . + /// + /// + /// + private sealed class Reconciler : EstimatorReconciler + { + public static Reconciler Inst = new Reconciler(); + + private Reconciler() + { + } + + public override IEstimator Reconcile(IHostEnvironment env, + PipelineColumn[] toOutput, + IReadOnlyDictionary inputNames, + IReadOnlyDictionary outputNames, + IReadOnlyCollection usedNames) + { + var cols = new ImageResizerTransform.ColumnInfo[toOutput.Length]; + for (int i = 0; i < toOutput.Length; ++i) + { + var outCol = (OutPipelineColumn)toOutput[i]; + cols[i] = outCol.MakeColumnInfo(inputNames[outCol._input], outputNames[outCol]); + } + return new ImageResizerEstimator(env, cols); + } + } + } } } diff --git a/src/Microsoft.ML.ImageAnalytics/ImageStaticPipe.cs b/src/Microsoft.ML.ImageAnalytics/ImageStaticPipe.cs new file mode 100644 index 0000000000..c6d64ec88e --- /dev/null +++ b/src/Microsoft.ML.ImageAnalytics/ImageStaticPipe.cs @@ -0,0 +1,169 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Drawing; +using Microsoft.ML.Data.StaticPipe.Runtime; + +namespace Microsoft.ML.Runtime.ImageAnalytics +{ + /// + /// A type used in the generic argument to . We must simultaneously distinguish + /// between a of fixed (with ) and unfixed (with this type), + /// in the static pipelines. + /// + public class UnknownSizeBitmap { private UnknownSizeBitmap() { } } + + /// + /// Extension methods for the static-pipeline over objects. + /// + public static class ImageStaticPipe + { + /// + /// Load an image from an input column that holds the paths to images. + /// + /// The scalar text column that holds paths to the images + /// If specified, paths are considered to be relative to this directory. 
+ /// However, since the transform can be persisted across machines, it is generally considered more + /// safe for users to simply always make their input paths absolute. + /// The loaded images + /// + public static Scalar LoadAsImage(this Scalar path, string relativeTo = null) + { + Contracts.CheckValue(path, nameof(path)); + Contracts.CheckValueOrNull(relativeTo); + return new ImageLoaderEstimator.OutPipelineColumn(path, relativeTo); + } + + /// + /// Converts the image to grayscale. + /// + /// The image to convert + /// The grayscale images + /// + public static Scalar AsGrayscale(this Scalar input) + { + Contracts.CheckValue(input, nameof(input)); + return new ImageGrayscaleEstimator.OutPipelineColumn(input); + } + + /// + /// Converts the image to grayscale. + /// + /// The image to convert + /// The grayscale images + /// + public static Scalar AsGrayscale(this Scalar input) + { + Contracts.CheckValue(input, nameof(input)); + return new ImageGrayscaleEstimator.OutPipelineColumn(input); + } + + /// + /// Given a column of images of unfixed size, resize the images so they have uniform size. + /// + /// The input images + /// The width to resize to + /// The height to resize to + /// The type of resizing to do + /// If cropping is necessary, at what position will the image be fixed? 
+ /// The now uniformly sized images + /// + public static Scalar Resize(this Scalar input, int width, int height, + ImageResizerTransform.ResizingKind resizing = ImageResizerTransform.ResizingKind.IsoCrop, + ImageResizerTransform.Anchor cropAnchor = ImageResizerTransform.Anchor.Center) + { + Contracts.CheckValue(input, nameof(input)); + Contracts.CheckParam(width > 0, nameof(width), "Must be positive"); + Contracts.CheckParam(height > 0, nameof(height), "Must be positive"); + Contracts.CheckParam(Enum.IsDefined(typeof(ImageResizerTransform.ResizingKind), resizing), nameof(resizing), "Undefined value detected"); + Contracts.CheckParam(Enum.IsDefined(typeof(ImageResizerTransform.Anchor), cropAnchor), nameof(cropAnchor), "Undefined value detected"); + + return new ImageResizerEstimator.OutPipelineColumn(input, width, height, resizing, cropAnchor); + } + + /// + /// Given a column of images, resize them to a new fixed size. + /// + /// The input images + /// The width to resize to + /// The height to resize to + /// The type of resizing to do + /// If cropping is necessary, at what + /// The resized images + /// + public static Scalar Resize(this Scalar input, int width, int height, + ImageResizerTransform.ResizingKind resizing = ImageResizerTransform.ResizingKind.IsoCrop, + ImageResizerTransform.Anchor cropAnchor = ImageResizerTransform.Anchor.Center) + { + Contracts.CheckValue(input, nameof(input)); + Contracts.CheckParam(width > 0, nameof(width), "Must be positive"); + Contracts.CheckParam(height > 0, nameof(height), "Must be positive"); + Contracts.CheckParam(Enum.IsDefined(typeof(ImageResizerTransform.ResizingKind), resizing), nameof(resizing), "Undefined value detected"); + Contracts.CheckParam(Enum.IsDefined(typeof(ImageResizerTransform.Anchor), cropAnchor), nameof(cropAnchor), "Undefined value detected"); + + return new ImageResizerEstimator.OutPipelineColumn(input, width, height, resizing, cropAnchor); + } + + /// + /// Vectorizes the image as the numeric 
values of its pixels converted and possibly transformed to floating point values. + /// The output vector is output in height then width major order, with the channels being the most minor (if + /// is true) or major (if is false) dimension. + /// + /// The input image to extract + /// Whether the alpha channel should be extracted + /// Whether the red channel should be extracted + /// Whether the green channel should be extracted + /// Whether the blue channel should be extracted + /// Whether the pixel values should be interleaved, as opposed to being separated by channel + /// Scale the normally 0 through 255 pixel values by this amount + /// Add this amount to the pixel values, before scaling + /// The vectorized image + /// + public static Vector ExtractPixels(this Scalar input, bool useAlpha = false, bool useRed = true, + bool useGreen = true, bool useBlue = true, bool interleaveArgb = false, float scale = 1.0f, float offset = 0.0f) + { + var colParams = new ImagePixelExtractorTransform.Column + { + UseAlpha = useAlpha, + UseRed = useRed, + UseGreen = useGreen, + UseBlue = useBlue, + InterleaveArgb = interleaveArgb, + Scale = scale, + Offset = offset, + Convert = true + }; + return new ImagePixelExtractorEstimator.OutPipelineColumn(input, colParams); + } + + /// + /// Vectorizes the image as the numeric byte values of its pixels. + /// The output vector is output in height then width major order, with the channels being the most minor (if + /// is true) or major (if is false) dimension. 
+ /// + /// The input image to extract + /// Whether the alpha channel should be extracted + /// Whether the red channel should be extracted + /// Whether the green channel should be extracted + /// Whether the blue channel should be extracted + /// Whether the pixel values should be interleaved, as opposed to being separated by channel + /// The vectorized image + /// + public static Vector ExtractPixelsAsBytes(this Scalar input, bool useAlpha = false, bool useRed = true, + bool useGreen = true, bool useBlue = true, bool interleaveArgb = false) + { + var colParams = new ImagePixelExtractorTransform.Column + { + UseAlpha = useAlpha, + UseRed = useRed, + UseGreen = useGreen, + UseBlue = useBlue, + InterleaveArgb = interleaveArgb, + Convert = false + }; + return new ImagePixelExtractorEstimator.OutPipelineColumn(input, colParams); + } + } +} diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/ContractsCheckTest.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/ContractsCheckTest.cs similarity index 96% rename from test/Microsoft.ML.CodeAnalyzer.Tests/ContractsCheckTest.cs rename to test/Microsoft.ML.CodeAnalyzer.Tests/Code/ContractsCheckTest.cs index a004994f4f..ff528bfcce 100644 --- a/test/Microsoft.ML.CodeAnalyzer.Tests/ContractsCheckTest.cs +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/ContractsCheckTest.cs @@ -2,14 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. 
-using System.IO; -using System.Reflection; -using System.Threading; using Microsoft.ML.CodeAnalyzer.Tests.Helpers; using Xunit; -using Xunit.Abstractions; -namespace Microsoft.ML.CodeAnalyzer.Tests +namespace Microsoft.ML.InternalCodeAnalyzer.Tests { public sealed class ContractsCheckTest : DiagnosticVerifier { diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/InstanceInitializerTest.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/InstanceInitializerTest.cs similarity index 93% rename from test/Microsoft.ML.CodeAnalyzer.Tests/InstanceInitializerTest.cs rename to test/Microsoft.ML.CodeAnalyzer.Tests/Code/InstanceInitializerTest.cs index 8baed00840..6c2c154bb7 100644 --- a/test/Microsoft.ML.CodeAnalyzer.Tests/InstanceInitializerTest.cs +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/InstanceInitializerTest.cs @@ -3,9 +3,10 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML.CodeAnalyzer.Tests.Helpers; +using Microsoft.ML.InternalCodeAnalyzer; using Xunit; -namespace Microsoft.ML.CodeAnalyzer.Tests +namespace Microsoft.ML.InternalCodeAnalyzer.Tests { public sealed class InstanceInitializerTest : DiagnosticVerifier { diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/NameTest.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/NameTest.cs similarity index 98% rename from test/Microsoft.ML.CodeAnalyzer.Tests/NameTest.cs rename to test/Microsoft.ML.CodeAnalyzer.Tests/Code/NameTest.cs index 2f87acd01a..5b79d92e2e 100644 --- a/test/Microsoft.ML.CodeAnalyzer.Tests/NameTest.cs +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/NameTest.cs @@ -3,9 +3,10 @@ // See the LICENSE file in the project root for more information. 
using Microsoft.ML.CodeAnalyzer.Tests.Helpers; +using Microsoft.ML.InternalCodeAnalyzer; using Xunit; -namespace Microsoft.ML.CodeAnalyzer.Tests +namespace Microsoft.ML.InternalCodeAnalyzer.Tests { public sealed class NameTest : DiagnosticVerifier { diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/ParameterVariableNameTest.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/ParameterVariableNameTest.cs similarity index 95% rename from test/Microsoft.ML.CodeAnalyzer.Tests/ParameterVariableNameTest.cs rename to test/Microsoft.ML.CodeAnalyzer.Tests/Code/ParameterVariableNameTest.cs index 673ead697d..3450e25c17 100644 --- a/test/Microsoft.ML.CodeAnalyzer.Tests/ParameterVariableNameTest.cs +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/ParameterVariableNameTest.cs @@ -3,9 +3,10 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML.CodeAnalyzer.Tests.Helpers; +using Microsoft.ML.InternalCodeAnalyzer; using Xunit; -namespace Microsoft.ML.CodeAnalyzer.Tests +namespace Microsoft.ML.InternalCodeAnalyzer.Tests { public sealed class ParameterVariableNameTest : DiagnosticVerifier { diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/Code/README.md b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/README.md new file mode 100644 index 0000000000..1ef55cfffe --- /dev/null +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/README.md @@ -0,0 +1 @@ +The tests in this directory are for testing the internal code analyzer (that is, the implementation of ML.NET itself), as opposed to general analyzer that ships with the library. 
\ No newline at end of file diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/SingleVariableDeclarationTest.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/SingleVariableDeclarationTest.cs similarity index 93% rename from test/Microsoft.ML.CodeAnalyzer.Tests/SingleVariableDeclarationTest.cs rename to test/Microsoft.ML.CodeAnalyzer.Tests/Code/SingleVariableDeclarationTest.cs index add947f5dd..13a1533217 100644 --- a/test/Microsoft.ML.CodeAnalyzer.Tests/SingleVariableDeclarationTest.cs +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/SingleVariableDeclarationTest.cs @@ -3,9 +3,10 @@ // See the LICENSE file in the project root for more information. using Microsoft.ML.CodeAnalyzer.Tests.Helpers; +using Microsoft.ML.InternalCodeAnalyzer; using Xunit; -namespace Microsoft.ML.CodeAnalyzer.Tests +namespace Microsoft.ML.InternalCodeAnalyzer.Tests { public sealed class SingleVariableDeclarationTest : DiagnosticVerifier { diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/TypeParamNameTest.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/TypeParamNameTest.cs similarity index 93% rename from test/Microsoft.ML.CodeAnalyzer.Tests/TypeParamNameTest.cs rename to test/Microsoft.ML.CodeAnalyzer.Tests/Code/TypeParamNameTest.cs index b9de9bf42f..f7c1c38030 100644 --- a/test/Microsoft.ML.CodeAnalyzer.Tests/TypeParamNameTest.cs +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Code/TypeParamNameTest.cs @@ -3,9 +3,10 @@ // See the LICENSE file in the project root for more information. 
using Microsoft.ML.CodeAnalyzer.Tests.Helpers; +using Microsoft.ML.InternalCodeAnalyzer; using Xunit; -namespace Microsoft.ML.CodeAnalyzer.Tests +namespace Microsoft.ML.InternalCodeAnalyzer.Tests { public sealed class TypeParamNameTest : DiagnosticVerifier { diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/Microsoft.ML.CodeAnalyzer.Tests.csproj b/test/Microsoft.ML.CodeAnalyzer.Tests/Microsoft.ML.CodeAnalyzer.Tests.csproj index f83d9682fe..10359a4739 100644 --- a/test/Microsoft.ML.CodeAnalyzer.Tests/Microsoft.ML.CodeAnalyzer.Tests.csproj +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Microsoft.ML.CodeAnalyzer.Tests.csproj @@ -5,13 +5,14 @@ %(RecursiveDir)%(Filename)%(Extension) - - + + + \ No newline at end of file diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/TypeIsSchemaShapeResource.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/TypeIsSchemaShapeResource.cs new file mode 100644 index 0000000000..3cb932e992 --- /dev/null +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/TypeIsSchemaShapeResource.cs @@ -0,0 +1,44 @@ +using System; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.CommandLine; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Data.StaticPipe; + +namespace Bubba +{ + class Foo + { + public static void Bar() + { + IHostEnvironment env = null; + var text = TextLoader.CreateReader(env, ctx => ( + label: ctx.LoadBool(0), + text: ctx.LoadText(1), + numericFeatures: ctx.LoadFloat(2, 5))); + + var est = Estimator.MakeNew(text); + // This should work. + est.Append(r => r.text); + // These should not. + est.Append(r => 5); + est.Append(r => new { r.text }); + est.Append(r => Tuple.Create(r.text, r.numericFeatures)); + // This should work. + est.Append(r => (a: r.text, b: r.label, c: (d: r.text, r.label))); + // This should not, and it should indicate a path to the problematic item. + est.Append(r => (a: r.text, b: r.label, c: (d: r.text, "yo"))); + + // Check a different entrance into static land now, with one of the asserts. 
+ var view = text.Read(null).AsDynamic; + // Despite the fact that the names are all wrong, this should still work + // from the point of view of this analyzer. + view.AssertStatic(env, c => ( + stay: c.KeyU4.TextValues.Scalar, + awhile: c.KeyU1.I4Values.Vector)); + // However, this should not. + view.AssertStatic(env, c => ( + and: c.KeyU4.TextValues.Scalar, + listen: "dawg")); + } + } +} diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/TypeIsSchemaShapeResourceChained.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/TypeIsSchemaShapeResourceChained.cs new file mode 100644 index 0000000000..43316dc567 --- /dev/null +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/Resources/TypeIsSchemaShapeResourceChained.cs @@ -0,0 +1,63 @@ +using System; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.CommandLine; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Data.StaticPipe; +using Microsoft.ML.Data.StaticPipe.Runtime; + +namespace Bubba +{ + class Foo + { + public static void Bar() + { + DataReader Foo1(Func m) + { + IHostEnvironment env = null; + // We ought to fail here. + return TextLoader.CreateReader(env, m); + } + + DataReader Foo2<[IsShape] T>(Func m) + { + IHostEnvironment env = null; + // We ought not to fail here due to that [IsShape], but calls to this method might fail. + return TextLoader.CreateReader(env, m); + } + + DataReader Foo3(Func m) + where T : PipelineColumn + { + IHostEnvironment env = null; + // This should work. + return TextLoader.CreateReader(env, m); + } + + DataReader Foo4(Func m) + where T : IEnumerable + { + IHostEnvironment env = null; + // This should not work. + return TextLoader.CreateReader(env, m); + } + + void Scratch() + { + // Neither of these two should fail here, though the method they're calling ought to fail. + var f1 = Foo1(ctx => ( + label: ctx.LoadBool(0), text: ctx.LoadText(1))); + var f2 = Foo1(ctx => ( + label: ctx.LoadBool(0), text: "hi")); + + // The first should succeed, the second should fail. 
+ var f3 = Foo2(ctx => ( + label: ctx.LoadBool(0), text: ctx.LoadText(1))); + var f4 = Foo2(ctx => ( + label: ctx.LoadBool(0), text: "hi")); + + // This should succeed. + var f5 = Foo3(ctx => ctx.LoadBool(0)); + } + } + } +} diff --git a/test/Microsoft.ML.CodeAnalyzer.Tests/TypeIsSchemaShapeTest.cs b/test/Microsoft.ML.CodeAnalyzer.Tests/TypeIsSchemaShapeTest.cs new file mode 100644 index 0000000000..5a8ad47b21 --- /dev/null +++ b/test/Microsoft.ML.CodeAnalyzer.Tests/TypeIsSchemaShapeTest.cs @@ -0,0 +1,59 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.CodeAnalyzer.Tests.Helpers; +using Xunit; + +namespace Microsoft.ML.Analyzer.Tests +{ + public sealed class TypeIsSchemaShapeTest : DiagnosticVerifier + { + private static string _srcResource; + internal static string Source => TestUtils.EnsureSourceLoaded(ref _srcResource, "TypeIsSchemaShapeResource.cs"); + + [Fact] + public void ReturnTypeIsSchemaShape() + { + var analyzer = GetCSharpDiagnosticAnalyzer(); + var diag = analyzer.SupportedDiagnostics[0]; + + string p(string i = "") => string.IsNullOrEmpty(i) ? 
"" : $" of item {i}"; + + var expected = new DiagnosticResult[] { + diag.CreateDiagnosticResult(23, 13, p()), + diag.CreateDiagnosticResult(24, 13, p()), + diag.CreateDiagnosticResult(25, 13, p()), + diag.CreateDiagnosticResult(29, 13, p("c.Item2")), + diag.CreateDiagnosticResult(39, 13, p("listen")), + }; + + VerifyCSharpDiagnostic(Source, expected); + } + + private static string _srcResourceChained; + internal static string SourceChained => TestUtils.EnsureSourceLoaded( + ref _srcResourceChained, "TypeIsSchemaShapeResourceChained.cs"); + + [Fact] + public void ReturnTypeIsSchemaShapeChained() + { + // This is a somewhat more complex example, where instead of direct usage the user of the API is devising their own + // function where the shape type is a generic type parameter. In this case, we would ideally like the analysis to get + // chained out of their function. + var analyzer = GetCSharpDiagnosticAnalyzer(); + var diag = analyzer.SupportedDiagnostics[0]; + var diagTp = analyzer.SupportedDiagnostics[1]; + + string p(string i = "") => string.IsNullOrEmpty(i) ? 
"" : $" of item {i}"; + + var expected = new DiagnosticResult[] { + diagTp.CreateDiagnosticResult(18, 24, "T"), + diagTp.CreateDiagnosticResult(41, 24, "T"), + diag.CreateDiagnosticResult(55, 26, p("text")), + }; + + VerifyCSharpDiagnostic(SourceChained, expected); + } + } +} diff --git a/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs b/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs index ed1780c6d7..b5b53677ab 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestTransposer.cs @@ -8,6 +8,7 @@ using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.TestFramework; using Xunit; using Xunit.Abstractions; @@ -234,7 +235,7 @@ public void TransposerSaverLoaderTest() { TransposeSaver saver = new TransposeSaver(Env, new TransposeSaver.Arguments()); saver.SaveData(mem, view, Utils.GetIdentityPermutation(view.Schema.ColumnCount)); - src = new BytesSource(mem.ToArray()); + src = new BytesStreamSource(mem.ToArray()); } TransposeLoader loader = new TransposeLoader(Env, new TransposeLoader.Arguments(), src); // First check whether this as an IDataView yields the same values. diff --git a/test/Microsoft.ML.StaticPipelineTesting/ImageAnalyticsTests.cs b/test/Microsoft.ML.StaticPipelineTesting/ImageAnalyticsTests.cs new file mode 100644 index 0000000000..5b084d601e --- /dev/null +++ b/test/Microsoft.ML.StaticPipelineTesting/ImageAnalyticsTests.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.ImageAnalytics; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.StaticPipelineTesting +{ + public sealed class ImageAnalyticsTests : MakeConsoleWork + { + public ImageAnalyticsTests(ITestOutputHelper output) + : base(output) + { + } + + [Fact] + public void SimpleImageSmokeTest() + { + var env = new TlcEnvironment(new SysRandom(0), verbose: true); + + var reader = TextLoader.CreateReader(env, + ctx => ctx.LoadText(0).LoadAsImage().AsGrayscale().Resize(10, 8).ExtractPixels()); + + var schema = reader.AsDynamic.GetOutputSchema(); + Assert.True(schema.TryGetColumnIndex("Data", out int col), "Could not find 'Data' column"); + var type = schema.GetColumnType(col); + Assert.True(type.IsKnownSizeVector, $"Type was supposed to be known size vector but was instead '{type}'"); + var vecType = type.AsVector; + Assert.Equal(NumberType.R4, vecType.ItemType); + Assert.Equal(3, vecType.DimCount); + Assert.Equal(3, vecType.GetDim(0)); + Assert.Equal(8, vecType.GetDim(1)); + Assert.Equal(10, vecType.GetDim(2)); + } + } +} diff --git a/test/Microsoft.ML.StaticPipelineTesting/Microsoft.ML.StaticPipelineTesting.csproj b/test/Microsoft.ML.StaticPipelineTesting/Microsoft.ML.StaticPipelineTesting.csproj new file mode 100644 index 0000000000..ad65c49804 --- /dev/null +++ b/test/Microsoft.ML.StaticPipelineTesting/Microsoft.ML.StaticPipelineTesting.csproj @@ -0,0 +1,12 @@ + + + CORECLR + + + + + + + + + \ No newline at end of file diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeFakes.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeFakes.cs new file mode 100644 index 0000000000..8313d94e7f --- /dev/null +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeFakes.cs @@ -0,0 +1,218 @@ +using Microsoft.ML.Core.Data; +using Microsoft.ML.Data.StaticPipe; +using Microsoft.ML.Data.StaticPipe.Runtime; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; 
+using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; + +// Holds some classes that superficially represent classes, at least sufficiently to give the idea of the +// statically typed columnar estimator helper API. As more "real" examples of the static functions get +// added, this file will gradully disappear. + +namespace FakeStaticPipes +{ + /// + /// This is a reconciler that doesn't really do anything, just a fake for testing the infrastructure. + /// + internal sealed class FakeTransformReconciler : EstimatorReconciler + { + private readonly string _name; + + public FakeTransformReconciler(string name) + { + _name = name; + } + + public override IEstimator Reconcile( + IHostEnvironment env, + PipelineColumn[] toOutput, + IReadOnlyDictionary inputNames, + IReadOnlyDictionary outputNames, + IReadOnlyCollection usedNames) + { + Console.WriteLine($"Constructing {_name} estimator!"); + + foreach (var col in toOutput) + { + if ((((IDeps)col).Deps?.Length ?? 
0) == 0) + Console.WriteLine($" Will make '{outputNames[col]}' from nothing"); + else + { + Console.WriteLine($" Will make '{outputNames[col]}' out of " + + string.Join(", ", ((IDeps)col).Deps.Select(d => $"'{inputNames[d]}'"))); + } + } + + return new FakeEstimator(); + } + + private sealed class FakeEstimator : IEstimator + { + public ITransformer Fit(IDataView input) => throw new NotImplementedException(); + public SchemaShape GetOutputSchema(SchemaShape inputSchema) => throw new NotImplementedException(); + } + + private interface IDeps { PipelineColumn[] Deps { get; } } + + private sealed class AScalar : Scalar, IDeps { public AScalar(Reconciler rec, PipelineColumn[] dependencies) : base(rec, dependencies) { Deps = dependencies; } public PipelineColumn[] Deps { get; } } + private sealed class AVector : Vector, IDeps { public AVector(Reconciler rec, PipelineColumn[] dependencies) : base(rec, dependencies) { Deps = dependencies; } public PipelineColumn[] Deps { get; } } + private sealed class AVarVector : VarVector, IDeps { public AVarVector(Reconciler rec, PipelineColumn[] dependencies) : base(rec, dependencies) { Deps = dependencies; } public PipelineColumn[] Deps { get; } } + private sealed class AKey : Key, IDeps { public AKey(Reconciler rec, PipelineColumn[] dependencies) : base(rec, dependencies) { Deps = dependencies; } public PipelineColumn[] Deps { get; } } + private sealed class AKey : Key, IDeps { public AKey(Reconciler rec, PipelineColumn[] dependencies) : base(rec, dependencies) { Deps = dependencies; } public PipelineColumn[] Deps { get; } } + private sealed class AVarKey : VarKey, IDeps { public AVarKey(Reconciler rec, PipelineColumn[] dependencies) : base(rec, dependencies) { Deps = dependencies; } public PipelineColumn[] Deps { get; } } + + public Scalar Scalar(params PipelineColumn[] dependencies) => new AScalar(this, dependencies); + public Vector Vector(params PipelineColumn[] dependencies) => new AVector(this, dependencies); + public 
VarVector VarVector(params PipelineColumn[] dependencies) => new AVarVector(this, dependencies); + public Key Key(params PipelineColumn[] dependencies) => new AKey(this, dependencies); + public Key Key(params PipelineColumn[] dependencies) => new AKey(this, dependencies); + public VarKey VarKey(params PipelineColumn[] dependencies) => new AVarKey(this, dependencies); + } + + public static class ConcatTransformExtensions + { + private static FakeTransformReconciler _rec = new FakeTransformReconciler("Concat"); + + public sealed class ScalarOrVector : ScalarOrVectorOrVarVector + { + private ScalarOrVector(PipelineColumn col) : base(col) { } + public static implicit operator ScalarOrVector(Scalar c) => new ScalarOrVector(c); + public static implicit operator ScalarOrVector(Vector c) => new ScalarOrVector(c); + } + + private interface IContainsColumn + { + PipelineColumn WrappedColumn { get; } + } + + + public class ScalarOrVectorOrVarVector : IContainsColumn + { + private readonly PipelineColumn _wrappedColumn; + PipelineColumn IContainsColumn.WrappedColumn => _wrappedColumn; + + private protected ScalarOrVectorOrVarVector(PipelineColumn col) + { + _wrappedColumn = col; + } + + public static implicit operator ScalarOrVectorOrVarVector(VarVector c) + => new ScalarOrVectorOrVarVector(c); + } + + private static PipelineColumn[] Helper(PipelineColumn first, IList> list) + { + PipelineColumn[] retval = new PipelineColumn[list.Count + 1]; + retval[0] = first; + for (int i = 0; i < list.Count; ++i) + retval[i + 1] = ((IContainsColumn)list[i]).WrappedColumn; + return retval; + } + + public static Vector ConcatWith(this Scalar me, params ScalarOrVector[] i) + => _rec.Vector(Helper(me, i)); + public static Vector ConcatWith(this Vector me, params ScalarOrVector[] i) + => _rec.Vector(Helper(me, i)); + + public static VarVector ConcatWith(this Scalar me, params ScalarOrVectorOrVarVector[] i) + => _rec.VarVector(Helper(me, i)); + public static VarVector ConcatWith(this Vector me, 
params ScalarOrVectorOrVarVector[] i) + => _rec.VarVector(Helper(me, i)); + public static VarVector ConcatWith(this VarVector me, params ScalarOrVectorOrVarVector[] i) + => _rec.VarVector(Helper(me, i)); + } + + public static class NormalizeTransformExtensions + { + private static FakeTransformReconciler _rec = new FakeTransformReconciler("Normalize"); + + public static Vector Normalize(this Vector me) + => _rec.Vector(me); + + public static Vector Normalize(this Vector me) + => _rec.Vector(me); + } + + public static class WordTokenizeTransformExtensions + { + private static FakeTransformReconciler _rec = new FakeTransformReconciler("WordTokenize"); + + public static VarVector Tokenize(this Scalar me) + => _rec.VarVector(me); + public static VarVector Tokenize(this Vector me) + => _rec.VarVector(me); + public static VarVector Tokenize(this VarVector me) + => _rec.VarVector(me); + } + + public static class TermTransformExtensions + { + private static FakeTransformReconciler _rec = new FakeTransformReconciler("Term"); + + public static Key Dictionarize(this Scalar me) + => _rec.Key(me); + public static Vector> Dictionarize(this Vector me) + => _rec.Vector>(me); + public static VarVector> Dictionarize(this VarVector me) + => _rec.VarVector>(me); + } + + public static class KeyToVectorTransformExtensions + { + private static FakeTransformReconciler _rec = new FakeTransformReconciler("KeyToVector"); + + public static Vector BagVectorize(this VarVector> me) + => _rec.Vector(me); + public static Vector BagVectorize(this VarVector> me) + => _rec.Vector(me); + } + + public static class TextTransformExtensions + { + private static FakeTransformReconciler _rec = new FakeTransformReconciler("TextTransform"); + + /// + /// Performs text featurization on the input text. This will tokenize, do n-gram featurization, + /// dictionary based term mapping, and finally produce a word-bag vector for the output. 
+ /// + /// The text to featurize + /// + public static Vector TextFeaturizer(this Scalar me, bool keepDiacritics = true) + => _rec.Vector(me); + } + + public static class TrainerTransformExtensions + { + private static FakeTransformReconciler _rec = new FakeTransformReconciler("LinearBinaryClassification"); + + /// + /// Trains a linear predictor using logistic regression. + /// + /// The target label for this binary classification task + /// The features to train on. Should be normalized. + /// A tuple of columns representing the score, the calibrated score as a probability, and the boolean predicted label + public static (Scalar score, Scalar probability, Scalar predictedLabel) TrainLinearClassification(this Scalar label, Vector features) + => (_rec.Scalar(label, features), _rec.Scalar(label, features), _rec.Scalar(label, features)); + } + + public static class HashTransformExtensions + { + private static FakeTransformReconciler _rec = new FakeTransformReconciler("Hash"); + + public static Key Hash(this Scalar me) + => _rec.Key(me); + public static Key Hash(this Scalar me, int invertHashTokens) + => _rec.Key(me); + public static Vector> Hash(this Vector me) + => _rec.Vector>(me); + public static Vector> Hash(this Vector me, int invertHashTokens) + => _rec.Vector>(me); + public static VarVector> Hash(this VarVector me) + => _rec.VarVector>(me); + public static VarVector> Hash(this VarVector me, int invertHashTokens) + => _rec.VarVector>(me); + } +} diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs new file mode 100644 index 0000000000..814ce5fbaf --- /dev/null +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -0,0 +1,260 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using Microsoft.ML.Data.StaticPipe; +using Microsoft.ML.Data.StaticPipe.Runtime; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.TestFramework; +using System; +using System.Collections.Generic; +using System.IO; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.StaticPipelineTesting +{ + public abstract class MakeConsoleWork : IDisposable + { + private readonly ITestOutputHelper _output; + private readonly TextWriter _originalOut; + private readonly TextWriter _textWriter; + + public MakeConsoleWork(ITestOutputHelper output) + { + _output = output; + _originalOut = Console.Out; + _textWriter = new StringWriter(); + Console.SetOut(_textWriter); + } + + public void Dispose() + { + _output.WriteLine(_textWriter.ToString()); + Console.SetOut(_originalOut); + } + } + + public sealed class StaticPipeTests : MakeConsoleWork + { + public StaticPipeTests(ITestOutputHelper output) + : base(output) + { + } + + private void CheckSchemaHasColumn(ISchema schema, string name, out int idx) + => Assert.True(schema.TryGetColumnIndex(name, out idx), "Could not find column '" + name + "'"); + + [Fact] + public void SimpleTextLoaderCopyColumnsTest() + { + var env = new TlcEnvironment(new SysRandom(0), verbose: true); + + const string data = "0 hello 3.14159 -0 2\n" + + "1 1 2 4 15"; + var dataSource = new BytesStreamSource(data); + + var text = TextLoader.CreateReader(env, ctx => ( + label: ctx.LoadBool(0), + text: ctx.LoadText(1), + numericFeatures: ctx.LoadFloat(2, null)), // If fit correctly, this ought to be equivalent to max of 4, that is, length of 3. + dataSource, separator: ' '); + + // While we have a type-safe wrapper for `IDataView` it is utterly useless except as an input to the `Fit` functions + // of the other statically typed wrappers. We perhaps ought to make it useful in its own right, but perhaps not now. + // For now, just operate over the actual `IDataView`. 
+ var textData = text.Read(dataSource).AsDynamic; + + var schema = textData.Schema; + // First verify that the columns are there. There ought to be at least one column corresponding to the identifiers in the tuple. + CheckSchemaHasColumn(schema, "label", out int labelIdx); + CheckSchemaHasColumn(schema, "text", out int textIdx); + CheckSchemaHasColumn(schema, "numericFeatures", out int numericFeaturesIdx); + // Next verify they have the expected types. + Assert.Equal(BoolType.Instance, schema.GetColumnType(labelIdx)); + Assert.Equal(TextType.Instance, schema.GetColumnType(textIdx)); + Assert.Equal(new VectorType(NumberType.R4, 3), schema.GetColumnType(numericFeaturesIdx)); + // Next actually inspect the data. + using (var cursor = textData.GetRowCursor(c => true)) + { + var labelGetter = cursor.GetGetter(labelIdx); + var textGetter = cursor.GetGetter(textIdx); + var numericFeaturesGetter = cursor.GetGetter>(numericFeaturesIdx); + + DvBool labelVal = default; + DvText textVal = default; + VBuffer numVal = default; + + void CheckValuesSame(bool bl, string tx, float v0, float v1, float v2) + { + labelGetter(ref labelVal); + textGetter(ref textVal); + numericFeaturesGetter(ref numVal); + + Assert.Equal((DvBool)bl, labelVal); + Assert.Equal(new DvText(tx), textVal); + Assert.Equal(3, numVal.Length); + Assert.Equal(v0, numVal.GetItemOrDefault(0)); + Assert.Equal(v1, numVal.GetItemOrDefault(1)); + Assert.Equal(v2, numVal.GetItemOrDefault(2)); + } + + Assert.True(cursor.MoveNext(), "Could not move even to first row"); + CheckValuesSame(false, "hello", 3.14159f, -0f, 2f); + Assert.True(cursor.MoveNext(), "Could not move to second row"); + CheckValuesSame(true, "1", 2f, 4f, 15f); + Assert.False(cursor.MoveNext(), "Moved to third row, but there should have been only two"); + } + + // The next step where we shuffle the names around a little bit is one where we are + // testing out the implicit usage of copy columns. 
+ + var est = Estimator.MakeNew(text).Append(r => (text: r.label, label: r.numericFeatures)); + var newText = text.Append(est); + var newTextData = newText.Fit(dataSource).Read(dataSource); + + schema = newTextData.AsDynamic.Schema; + // First verify that the columns are there. There ought to be at least one column corresponding to the identifiers in the tuple. + CheckSchemaHasColumn(schema, "label", out labelIdx); + CheckSchemaHasColumn(schema, "text", out textIdx); + // Next verify they have the expected types. + Assert.Equal(BoolType.Instance, schema.GetColumnType(textIdx)); + Assert.Equal(new VectorType(NumberType.R4, 3), schema.GetColumnType(labelIdx)); + } + + private static KeyValuePair P(string name, ColumnType type) + => new KeyValuePair(name, type); + + [Fact] + public void StaticPipeAssertSimple() + { + var env = new TlcEnvironment(new SysRandom(0), verbose: true); + var schema = new SimpleSchema(env, + P("hello", TextType.Instance), + P("my", new VectorType(NumberType.I8, 5)), + P("friend", new KeyType(DataKind.U4, 0, 3))); + var view = new EmptyDataView(env, schema); + + view.AssertStatic(env, c => ( + my: c.I8.Vector, + friend: c.KeyU4.NoValue.Scalar, + hello: c.Text.Scalar + )); + } + + private sealed class MetaCounted : ICounted + { + public long Position => 0; + public long Batch => 0; + public ValueGetter GetIdGetter() => (ref UInt128 v) => v = default; + } + + [Fact] + public void StaticPipeAssertKeys() + { + var env = new TlcEnvironment(new SysRandom(0), verbose: true); + var counted = new MetaCounted(); + + // We'll test a few things here. First, the case where the key-value metadata is text. 
+ var metaValues1 = new VBuffer(3, new[] { new DvText("a"), new DvText("b"), new DvText("c") }); + var meta1 = RowColumnUtils.GetColumn(MetadataUtils.Kinds.KeyValues, new VectorType(TextType.Instance, 3), ref metaValues1); + uint value1 = 2; + var col1 = RowColumnUtils.GetColumn("stay", new KeyType(DataKind.U4, 0, 3), ref value1, RowColumnUtils.GetRow(counted, meta1)); + + // Next the case where those values are ints. + var metaValues2 = new VBuffer(3, new DvInt4[] { 1, 2, 3, 4 }); + var meta2 = RowColumnUtils.GetColumn(MetadataUtils.Kinds.KeyValues, new VectorType(NumberType.I4, 4), ref metaValues2); + var value2 = new VBuffer(2, 0, null, null); + var col2 = RowColumnUtils.GetColumn("awhile", new VectorType(new KeyType(DataKind.U1, 2, 4), 2), ref value2, RowColumnUtils.GetRow(counted, meta2)); + + // Then the case where a value of that kind exists, but is of not of the right kind, in which case it should not be identified as containing that metadata. + var metaValues3 = (float)2; + var meta3 = RowColumnUtils.GetColumn(MetadataUtils.Kinds.KeyValues, NumberType.R4, ref metaValues3); + var value3 = (ushort)1; + var col3 = RowColumnUtils.GetColumn("and", new KeyType(DataKind.U2, 0, 2), ref value3, RowColumnUtils.GetRow(counted, meta3)); + + // Then a final case where metadata of that kind is actaully simply altogether absent. + var value4 = new VBuffer(5, 0, null, null); + var col4 = RowColumnUtils.GetColumn("listen", new VectorType(new KeyType(DataKind.U4, 0, 2)), ref value4); + + // Finally compose a trivial data view out of all this. + var row = RowColumnUtils.GetRow(counted, col1, col2, col3, col4); + var view = RowCursorUtils.RowAsDataView(env, row); + + // Whew! I'm glad that's over with. Let us start running the test in ernest. + // First let's do a direct match of the types to ensure that works. 
+ view.AssertStatic(env, c => ( + stay: c.KeyU4.TextValues.Scalar, + awhile: c.KeyU1.I4Values.Vector, + and: c.KeyU2.NoValue.Scalar, + listen: c.KeyU4.NoValue.VarVector)); + + // Next let's match against the superclasses (where no value types are + // asserted), to ensure that the less specific case still passes. + view.AssertStatic(env, c => ( + stay: c.KeyU4.NoValue.Scalar, + awhile: c.KeyU1.NoValue.Vector, + and: c.KeyU2.NoValue.Scalar, + listen: c.KeyU4.NoValue.VarVector)); + + // Here we assert a subset. + view.AssertStatic(env, c => ( + stay: c.KeyU4.TextValues.Scalar, + awhile: c.KeyU1.I4Values.Vector)); + + // OK. Now we've confirmed the basic stuff works, let's check other scenarios. + // Due to the fact that we cannot yet assert only a *single* column, these always appear + // in at least pairs. + + // First try to get the right type of exception to test against. + Type e = null; + try + { + view.AssertStatic(env, c => ( + stay: c.KeyU4.TextValues.Scalar, + awhile: c.KeyU2.I4Values.Vector)); + } + catch (Exception eCaught) + { + e = eCaught.GetType(); + } + Assert.NotNull(e); + + // What if the key representation type is wrong? + Assert.Throws(e, () => + view.AssertStatic(env, c => ( + stay: c.KeyU4.TextValues.Scalar, + awhile: c.KeyU2.I4Values.Vector))); + + // What if the key value type is wrong? + Assert.Throws(e, () => + view.AssertStatic(env, c => ( + stay: c.KeyU4.TextValues.Scalar, + awhile: c.KeyU1.I2Values.Vector))); + + // Same two tests, but for scalar? + Assert.Throws(e, () => + view.AssertStatic(env, c => ( + stay: c.KeyU2.TextValues.Scalar, + awhile: c.KeyU1.I2Values.Vector))); + + Assert.Throws(e, () => + view.AssertStatic(env, c => ( + stay: c.KeyU4.BoolValues.Scalar, + awhile: c.KeyU1.I2Values.Vector))); + + // How about if we misidentify the vectorness? + Assert.Throws(e, () => + view.AssertStatic(env, c => ( + stay: c.KeyU4.TextValues.Vector, + awhile: c.KeyU1.I2Values.Vector))); + + // How about the names? 
+ Assert.Throws(e, () => + view.AssertStatic(env, c => ( + stay: c.KeyU4.TextValues.Scalar, + alot: c.KeyU1.I4Values.Vector))); + } + } +} diff --git a/test/Microsoft.ML.TestFramework/BytesStreamSource.cs b/test/Microsoft.ML.TestFramework/BytesStreamSource.cs new file mode 100644 index 0000000000..d77c99f848 --- /dev/null +++ b/test/Microsoft.ML.TestFramework/BytesStreamSource.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.IO; +using System.Text; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; + +namespace Microsoft.ML.TestFramework +{ + /// + /// On demand open up new streams over a source of bytes. + /// + public sealed class BytesStreamSource : IMultiStreamSource + { + private readonly byte[] _data; + + public int Count => 1; + + public BytesStreamSource(byte[] data) + { + Contracts.AssertValue(data); + _data = data; + } + + public BytesStreamSource(string data) + : this(Encoding.UTF8.GetBytes(data)) + { + } + + public string GetPathOrNull(int index) + { + Contracts.Check(index == 0); + return null; + } + + public Stream Open(int index) + { + Contracts.Check(index == 0); + return new MemoryStream(_data, writable: false); + } + + public TextReader OpenTextReader(int index) + { + return new StreamReader(Open(index)); + } + } +} diff --git a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs index a402e9dd38..57132f53c8 100644 --- a/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs +++ b/test/Microsoft.ML.TestFramework/DataPipe/TestDataPipeBase.cs @@ -1250,38 +1250,5 @@ protected bool ComparePicture(Picture v1, Picture v2) return true; } #endif - - /// - /// On demand open up new streams over a source of bytes. 
- /// - protected internal sealed class BytesSource : IMultiStreamSource - { - private readonly byte[] _data; - - public BytesSource(byte[] data) - { - Contracts.AssertValue(data); - _data = data; - } - - public int Count { get { return 1; } } - - public string GetPathOrNull(int index) - { - Contracts.Check(index == 0); - return null; - } - - public Stream Open(int index) - { - Contracts.Check(index == 0); - return new MemoryStream(_data, writable: false); - } - - public TextReader OpenTextReader(int index) - { - return new StreamReader(Open(index)); - } - } } } diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj index 39301dc429..d5f5759327 100644 --- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj +++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj @@ -5,6 +5,7 @@ + diff --git a/tools-local/Microsoft.ML.CodeAnalyzer/ContractsCheckAnalyzer.cs b/tools-local/Microsoft.ML.InternalCodeAnalyzer/ContractsCheckAnalyzer.cs similarity index 99% rename from tools-local/Microsoft.ML.CodeAnalyzer/ContractsCheckAnalyzer.cs rename to tools-local/Microsoft.ML.InternalCodeAnalyzer/ContractsCheckAnalyzer.cs index 4f96d11e79..a271c7e616 100644 --- a/tools-local/Microsoft.ML.CodeAnalyzer/ContractsCheckAnalyzer.cs +++ b/tools-local/Microsoft.ML.InternalCodeAnalyzer/ContractsCheckAnalyzer.cs @@ -9,7 +9,7 @@ using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Diagnostics; -namespace Microsoft.ML.CodeAnalyzer +namespace Microsoft.ML.InternalCodeAnalyzer { [DiagnosticAnalyzer(LanguageNames.CSharp)] public sealed class ContractsCheckAnalyzer : DiagnosticAnalyzer diff --git a/tools-local/Microsoft.ML.CodeAnalyzer/ContractsCheckNameofFixProvider.cs b/tools-local/Microsoft.ML.InternalCodeAnalyzer/ContractsCheckNameofFixProvider.cs similarity index 99% rename from tools-local/Microsoft.ML.CodeAnalyzer/ContractsCheckNameofFixProvider.cs rename to 
tools-local/Microsoft.ML.InternalCodeAnalyzer/ContractsCheckNameofFixProvider.cs index d2f6a1a3ea..378082e5a3 100644 --- a/tools-local/Microsoft.ML.CodeAnalyzer/ContractsCheckNameofFixProvider.cs +++ b/tools-local/Microsoft.ML.InternalCodeAnalyzer/ContractsCheckNameofFixProvider.cs @@ -13,7 +13,7 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; -namespace Microsoft.ML.CodeAnalyzer +namespace Microsoft.ML.InternalCodeAnalyzer { using Debug = System.Diagnostics.Debug; diff --git a/tools-local/Microsoft.ML.CodeAnalyzer/Descriptions.Designer.cs b/tools-local/Microsoft.ML.InternalCodeAnalyzer/Descriptions.Designer.cs similarity index 96% rename from tools-local/Microsoft.ML.CodeAnalyzer/Descriptions.Designer.cs rename to tools-local/Microsoft.ML.InternalCodeAnalyzer/Descriptions.Designer.cs index 730cfe6289..4f9067ac6b 100644 --- a/tools-local/Microsoft.ML.CodeAnalyzer/Descriptions.Designer.cs +++ b/tools-local/Microsoft.ML.InternalCodeAnalyzer/Descriptions.Designer.cs @@ -8,7 +8,7 @@ // //------------------------------------------------------------------------------ -namespace Microsoft.ML.CodeAnalyzer { +namespace Microsoft.ML.InternalCodeAnalyzer { using System; using System.Reflection; @@ -40,7 +40,7 @@ internal Descriptions() { internal static global::System.Resources.ResourceManager ResourceManager { get { if (object.ReferenceEquals(resourceMan, null)) { - global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Microsoft.ML.CodeAnalyzer.Descriptions", typeof(Descriptions).GetTypeInfo().Assembly); + global::System.Resources.ResourceManager temp = new global::System.Resources.ResourceManager("Microsoft.ML.InternalCodeAnalyzer.Descriptions", typeof(Descriptions).GetTypeInfo().Assembly); resourceMan = temp; } return resourceMan; diff --git a/tools-local/Microsoft.ML.CodeAnalyzer/Descriptions.resx b/tools-local/Microsoft.ML.InternalCodeAnalyzer/Descriptions.resx similarity index 100% rename from 
tools-local/Microsoft.ML.CodeAnalyzer/Descriptions.resx rename to tools-local/Microsoft.ML.InternalCodeAnalyzer/Descriptions.resx diff --git a/tools-local/Microsoft.ML.CodeAnalyzer/InstanceInitializerAnalyzer.cs b/tools-local/Microsoft.ML.InternalCodeAnalyzer/InstanceInitializerAnalyzer.cs similarity index 98% rename from tools-local/Microsoft.ML.CodeAnalyzer/InstanceInitializerAnalyzer.cs rename to tools-local/Microsoft.ML.InternalCodeAnalyzer/InstanceInitializerAnalyzer.cs index f15d2c192e..c7ee67537a 100644 --- a/tools-local/Microsoft.ML.CodeAnalyzer/InstanceInitializerAnalyzer.cs +++ b/tools-local/Microsoft.ML.InternalCodeAnalyzer/InstanceInitializerAnalyzer.cs @@ -9,7 +9,7 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.Diagnostics; -namespace Microsoft.ML.CodeAnalyzer +namespace Microsoft.ML.InternalCodeAnalyzer { [DiagnosticAnalyzer(LanguageNames.CSharp)] public sealed class InstanceInitializerAnalyzer : DiagnosticAnalyzer diff --git a/tools-local/Microsoft.ML.CodeAnalyzer/Microsoft.ML.CodeAnalyzer.csproj b/tools-local/Microsoft.ML.InternalCodeAnalyzer/Microsoft.ML.InternalCodeAnalyzer.csproj similarity index 73% rename from tools-local/Microsoft.ML.CodeAnalyzer/Microsoft.ML.CodeAnalyzer.csproj rename to tools-local/Microsoft.ML.InternalCodeAnalyzer/Microsoft.ML.InternalCodeAnalyzer.csproj index 46f8e8df15..cb51f63307 100644 --- a/tools-local/Microsoft.ML.CodeAnalyzer/Microsoft.ML.CodeAnalyzer.csproj +++ b/tools-local/Microsoft.ML.InternalCodeAnalyzer/Microsoft.ML.InternalCodeAnalyzer.csproj @@ -5,9 +5,9 @@ - - - + + + diff --git a/tools-local/Microsoft.ML.CodeAnalyzer/NameAnalyzer.cs b/tools-local/Microsoft.ML.InternalCodeAnalyzer/NameAnalyzer.cs similarity index 99% rename from tools-local/Microsoft.ML.CodeAnalyzer/NameAnalyzer.cs rename to tools-local/Microsoft.ML.InternalCodeAnalyzer/NameAnalyzer.cs index 6e7c77100c..c5f168eb98 100644 --- a/tools-local/Microsoft.ML.CodeAnalyzer/NameAnalyzer.cs +++ 
b/tools-local/Microsoft.ML.InternalCodeAnalyzer/NameAnalyzer.cs @@ -10,7 +10,7 @@ using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Diagnostics; -namespace Microsoft.ML.CodeAnalyzer +namespace Microsoft.ML.InternalCodeAnalyzer { internal enum NameType { diff --git a/tools-local/Microsoft.ML.CodeAnalyzer/NameFixProvider.cs b/tools-local/Microsoft.ML.InternalCodeAnalyzer/NameFixProvider.cs similarity index 99% rename from tools-local/Microsoft.ML.CodeAnalyzer/NameFixProvider.cs rename to tools-local/Microsoft.ML.InternalCodeAnalyzer/NameFixProvider.cs index 22ff9c383e..4589441e6b 100644 --- a/tools-local/Microsoft.ML.CodeAnalyzer/NameFixProvider.cs +++ b/tools-local/Microsoft.ML.InternalCodeAnalyzer/NameFixProvider.cs @@ -17,7 +17,7 @@ using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.Rename; -namespace Microsoft.ML.CodeAnalyzer +namespace Microsoft.ML.InternalCodeAnalyzer { // This is somewhat difficult. The trouble is, if a name is in a bad state, it is // actually rather difficult to come up with a general procedure to "fix" it. 
We diff --git a/tools-local/Microsoft.ML.CodeAnalyzer/ParameterVariableNameAnalyzer.cs b/tools-local/Microsoft.ML.InternalCodeAnalyzer/ParameterVariableNameAnalyzer.cs similarity index 98% rename from tools-local/Microsoft.ML.CodeAnalyzer/ParameterVariableNameAnalyzer.cs rename to tools-local/Microsoft.ML.InternalCodeAnalyzer/ParameterVariableNameAnalyzer.cs index 7496609778..04b7373626 100644 --- a/tools-local/Microsoft.ML.CodeAnalyzer/ParameterVariableNameAnalyzer.cs +++ b/tools-local/Microsoft.ML.InternalCodeAnalyzer/ParameterVariableNameAnalyzer.cs @@ -9,7 +9,7 @@ using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Diagnostics; -namespace Microsoft.ML.CodeAnalyzer +namespace Microsoft.ML.InternalCodeAnalyzer { [DiagnosticAnalyzer(LanguageNames.CSharp)] public sealed class ParameterVariableNameAnalyzer : DiagnosticAnalyzer diff --git a/tools-local/Microsoft.ML.CodeAnalyzer/SingleVariableDeclarationAnalyzer.cs b/tools-local/Microsoft.ML.InternalCodeAnalyzer/SingleVariableDeclarationAnalyzer.cs similarity index 98% rename from tools-local/Microsoft.ML.CodeAnalyzer/SingleVariableDeclarationAnalyzer.cs rename to tools-local/Microsoft.ML.InternalCodeAnalyzer/SingleVariableDeclarationAnalyzer.cs index eceb2b3f0d..8b7efb8f53 100644 --- a/tools-local/Microsoft.ML.CodeAnalyzer/SingleVariableDeclarationAnalyzer.cs +++ b/tools-local/Microsoft.ML.InternalCodeAnalyzer/SingleVariableDeclarationAnalyzer.cs @@ -9,7 +9,7 @@ using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Diagnostics; -namespace Microsoft.ML.CodeAnalyzer +namespace Microsoft.ML.InternalCodeAnalyzer { [DiagnosticAnalyzer(LanguageNames.CSharp)] public sealed class SingleVariableDeclarationAnalyzer : DiagnosticAnalyzer diff --git a/tools-local/Microsoft.ML.CodeAnalyzer/TypeParamNameAnalyzer.cs b/tools-local/Microsoft.ML.InternalCodeAnalyzer/TypeParamNameAnalyzer.cs similarity index 97% rename from tools-local/Microsoft.ML.CodeAnalyzer/TypeParamNameAnalyzer.cs rename to 
tools-local/Microsoft.ML.InternalCodeAnalyzer/TypeParamNameAnalyzer.cs index 973c9a7b0b..8ba7de74b5 100644 --- a/tools-local/Microsoft.ML.CodeAnalyzer/TypeParamNameAnalyzer.cs +++ b/tools-local/Microsoft.ML.InternalCodeAnalyzer/TypeParamNameAnalyzer.cs @@ -8,7 +8,7 @@ using Microsoft.CodeAnalysis.CSharp.Syntax; using Microsoft.CodeAnalysis.Diagnostics; -namespace Microsoft.ML.CodeAnalyzer +namespace Microsoft.ML.InternalCodeAnalyzer { [DiagnosticAnalyzer(LanguageNames.CSharp)] public sealed class TypeParamNameAnalyzer : DiagnosticAnalyzer diff --git a/tools-local/Microsoft.ML.CodeAnalyzer/Utils.cs b/tools-local/Microsoft.ML.InternalCodeAnalyzer/Utils.cs similarity index 98% rename from tools-local/Microsoft.ML.CodeAnalyzer/Utils.cs rename to tools-local/Microsoft.ML.InternalCodeAnalyzer/Utils.cs index 0bc941906f..2937af07fe 100644 --- a/tools-local/Microsoft.ML.CodeAnalyzer/Utils.cs +++ b/tools-local/Microsoft.ML.InternalCodeAnalyzer/Utils.cs @@ -2,7 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -namespace Microsoft.ML.CodeAnalyzer +namespace Microsoft.ML.InternalCodeAnalyzer { internal static class Utils {