diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs new file mode 100644 index 0000000000..2d5200ed5e --- /dev/null +++ b/src/Microsoft.ML.Core/Data/IEstimator.cs @@ -0,0 +1,190 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Microsoft.ML.Core.Data +{ + /// + /// A set of 'requirements' to the incoming schema, as well as a set of 'promises' of the outgoing schema. + /// This is more relaxed than the proper , since it's only a subset of the columns, + /// and also since it doesn't specify exact 's for vectors and keys. + /// + public sealed class SchemaShape + { + public readonly Column[] Columns; + + public sealed class Column + { + public enum VectorKind + { + Scalar, + Vector, + VariableVector + } + + public readonly string Name; + public readonly VectorKind Kind; + public readonly DataKind ItemKind; + public readonly bool IsKey; + public readonly string[] MetadataKinds; + + public Column(string name, VectorKind vecKind, DataKind itemKind, bool isKey, string[] metadataKinds) + { + Contracts.CheckNonEmpty(name, nameof(name)); + Contracts.CheckValue(metadataKinds, nameof(metadataKinds)); + + Name = name; + Kind = vecKind; + ItemKind = itemKind; + IsKey = isKey; + MetadataKinds = metadataKinds; + } + } + + public SchemaShape(Column[] columns) + { + Contracts.CheckValue(columns, nameof(columns)); + Columns = columns; + } + + /// + /// Create a schema shape out of the fully defined schema. + /// + public static SchemaShape Create(ISchema schema) + { + Contracts.CheckValue(schema, nameof(schema)); + var cols = new List(); + + for (int iCol = 0; iCol < schema.ColumnCount; iCol++) + { + if (!schema.IsHidden(iCol)) + { + Column.VectorKind vecKind; + var type = schema.GetColumnType(iCol); + if (type.IsKnownSizeVector) + vecKind = Column.VectorKind.Vector; + else if (type.IsVector) + vecKind = Column.VectorKind.VariableVector; + else + vecKind = Column.VectorKind.Scalar; + + var kind = type.ItemType.RawKind; + var isKey = type.ItemType.IsKey; + + var metadataNames = schema.GetMetadataTypes(iCol) + .Select(kvp => kvp.Key) + .ToArray(); + cols.Add(new Column(schema.GetColumnName(iCol), vecKind, kind, isKey, metadataNames)); + } + } + return new SchemaShape(cols.ToArray()); + } + + /// + /// Returns the column with a specified , and null if there is no such column. + /// + public Column FindColumn(string name) + { + Contracts.CheckValue(name, nameof(name)); + return Columns.FirstOrDefault(x => x.Name == name); + } + + // REVIEW: I think we should have an IsCompatible method to check if it's OK to use one schema shape + // as an input to another schema shape. I started writing, but realized that there's more than one way to check for + // the 'compatibility': as in, 'CAN be compatible' vs. 'WILL be compatible'. + } + + /// + /// Exception class for schema validation errors. + /// + public class SchemaException : Exception + { + } + + /// + /// The 'data reader' takes a certain kind of input and turns it into an . + /// + /// The type of input the reader takes. + public interface IDataReader + { + /// + /// Produce the data view from the specified input. + /// Note that 's are lazy, so no actual reading happens here, just schema validation. 
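// --- Illustrative sketch (not part of the diff above) ---
// A minimal example of how a consumer might use SchemaShape to check an input 'requirement'
// before fitting. The method name, the 'inputSchema' argument and the "Features" column are
// hypothetical; SchemaShape.Create, FindColumn, VectorKind and SchemaException are the APIs
// defined above, and DataKind.R4 is assumed to denote a single-precision float item type.
private static void CheckFeaturesIsFloatVector(ISchema inputSchema)
{
    SchemaShape shape = SchemaShape.Create(inputSchema);
    SchemaShape.Column features = shape.FindColumn("Features");
    if (features == null
        || features.Kind != SchemaShape.Column.VectorKind.Vector
        || features.ItemKind != DataKind.R4)
    {
        // The input does not satisfy the requirement: a known-size float vector named "Features".
        throw new SchemaException();
    }
}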
+ /// + IDataView Read(TSource input); + + /// + /// The output schema of the reader. + /// + ISchema GetOutputSchema(); + } + + /// + /// Sometimes we need to 'fit' an . + /// A DataReader estimator is the object that does it. + /// + public interface IDataReaderEstimator + where TReader : IDataReader + { + /// + /// Train and return a data reader. + /// + /// REVIEW: you could consider the transformer to take a different , but we don't have such components + /// yet, so why complicate matters? + /// + TReader Fit(TSource input); + + /// + /// The 'promise' of the output schema. + /// It will be used for schema propagation. + /// + SchemaShape GetOutputSchema(); + } + + /// + /// The transformer is a component that transforms data. + /// It also supports 'schema propagation' to answer the question of 'how the data with this schema look after you transform it?'. + /// + public interface ITransformer + { + /// + /// Schema propagation for transformers. + /// Returns the output schema of the data, if the input schema is like the one provided. + /// Throws iff the input schema is not valid for the transformer. + /// + ISchema GetOutputSchema(ISchema inputSchema); + + /// + /// Take the data in, make transformations, output the data. + /// Note that 's are lazy, so no actual transformations happen here, just schema validation. + /// + IDataView Transform(IDataView input); + } + + /// + /// The estimator (in Spark terminology) is an 'untrained transformer'. It needs to 'fit' on the data to manufacture + /// a transformer. + /// It also provides the 'schema propagation' like transformers do, but over instead of . + /// + public interface IEstimator + where TTransformer : ITransformer + { + /// + /// Train and return a transformer. + /// + TTransformer Fit(IDataView input); + + /// + /// Schema propagation for estimators. + /// Returns the output schema shape of the estimator, if the input schema shape is like the one provided. + /// Throws iff the input schema is not valid for the estimator. + /// + SchemaShape GetOutputSchema(SchemaShape inputSchema); + } +} diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs new file mode 100644 index 0000000000..db4f207fa3 --- /dev/null +++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs @@ -0,0 +1,111 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Core.Data; +using Microsoft.ML.Runtime.Model; +using System.IO; + +namespace Microsoft.ML.Runtime.Data +{ + /// + /// This class represents a data reader that applies a transformer chain after reading. + /// It also has methods to save itself to a repository. + /// + public sealed class CompositeDataReader : IDataReader + where TLastTransformer : class, ITransformer + { + /// + /// The underlying data reader. + /// + public readonly IDataReader Reader; + /// + /// The chain of transformers (possibly empty) that are applied to data upon reading. + /// + public readonly TransformerChain Transformer; + + public CompositeDataReader(IDataReader reader, TransformerChain transformerChain = null) + { + Contracts.CheckValue(reader, nameof(reader)); + Contracts.CheckValueOrNull(transformerChain); + + Reader = reader; + Transformer = transformerChain ?? 
new TransformerChain(); + } + + public IDataView Read(TSource input) + { + var idv = Reader.Read(input); + idv = Transformer.Transform(idv); + return idv; + } + + public ISchema GetOutputSchema() + { + var s = Reader.GetOutputSchema(); + return Transformer.GetOutputSchema(s); + } + + /// + /// Append a new transformer to the end. + /// + /// The new composite data reader + public CompositeDataReader AppendTransformer(TNewLast transformer) + where TNewLast : class, ITransformer + { + Contracts.CheckValue(transformer, nameof(transformer)); + + return new CompositeDataReader(Reader, Transformer.Append(transformer)); + } + + /// + /// Save the contents to a stream, as a "model file". + /// + public void SaveTo(IHostEnvironment env, Stream outputStream) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(outputStream, nameof(outputStream)); + + env.Check(outputStream.CanWrite && outputStream.CanSeek, "Need a writable and seekable stream to save"); + using (var ch = env.Start("Saving pipeline")) + { + using (var rep = RepositoryWriter.CreateNew(outputStream, ch)) + { + ch.Trace("Saving data reader"); + ModelSaveContext.SaveModel(rep, Reader, "Reader"); + + ch.Trace("Saving transformer chain"); + ModelSaveContext.SaveModel(rep, Transformer, TransformerChain.LoaderSignature); + rep.Commit(); + } + } + } + } + + /// + /// Utility class to facilitate loading from a stream. + /// + public static class CompositeDataReader + { + /// + /// Load the pipeline from stream. + /// + public static CompositeDataReader LoadFrom(IHostEnvironment env, Stream stream) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(stream, nameof(stream)); + + env.Check(stream.CanRead && stream.CanSeek, "Need a readable and seekable stream to load"); + using (var rep = RepositoryReader.Open(stream, env)) + using (var ch = env.Start("Loading pipeline")) + { + ch.Trace("Loading data reader"); + ModelLoadContext.LoadModel, SignatureLoadModel>(env, out var reader, rep, "Reader"); + + ch.Trace("Loader transformer chain"); + ModelLoadContext.LoadModel, SignatureLoadModel>(env, out var transformerChain, rep, TransformerChain.LoaderSignature); + return new CompositeDataReader(reader, transformerChain); + } + } + } +} diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeReaderEstimator.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeReaderEstimator.cs new file mode 100644 index 0000000000..49f7f8b99d --- /dev/null +++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeReaderEstimator.cs @@ -0,0 +1,59 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Core.Data; + +namespace Microsoft.ML.Runtime.Data +{ + /// + /// An estimator class for composite data reader. + /// It can be used to build a 'trainable smart data reader', although this pattern is not very common. + /// + public sealed class CompositeReaderEstimator : IDataReaderEstimator> + where TLastTransformer : class, ITransformer + { + private readonly IDataReaderEstimator> _start; + private readonly EstimatorChain _estimatorChain; + + public CompositeReaderEstimator(IDataReaderEstimator> start, EstimatorChain estimatorChain = null) + { + Contracts.CheckValue(start, nameof(start)); + Contracts.CheckValueOrNull(estimatorChain); + + _start = start; + _estimatorChain = estimatorChain ?? new EstimatorChain(); + + // REVIEW: enforce that estimator chain can read the reader's schema. 
+ // Right now it throws. + // GetOutputSchema(); + } + + public CompositeDataReader Fit(TSource input) + { + var start = _start.Fit(input); + var idv = start.Read(input); + + var xfChain = _estimatorChain.Fit(idv); + return new CompositeDataReader(start, xfChain); + } + + public SchemaShape GetOutputSchema() + { + var shape = _start.GetOutputSchema(); + return _estimatorChain.GetOutputSchema(shape); + } + + /// + /// Append another estimator to the end. + /// + public CompositeReaderEstimator Append(IEstimator estimator) + where TNewTrans : class, ITransformer + { + Contracts.CheckValue(estimator, nameof(estimator)); + + return new CompositeReaderEstimator(_start, _estimatorChain.Append(estimator)); + } + } + +} diff --git a/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs b/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs new file mode 100644 index 0000000000..727dfa9e0b --- /dev/null +++ b/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs @@ -0,0 +1,78 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Core.Data; +using Microsoft.ML.Runtime.Internal.Utilities; +using System.Linq; + +namespace Microsoft.ML.Runtime.Data +{ + /// + /// Represents a chain (potentially empty) of estimators that end with a . + /// If the chain is empty, is always . + /// + public sealed class EstimatorChain : IEstimator> + where TLastTransformer : class, ITransformer + { + private readonly TransformerScope[] _scopes; + private readonly IEstimator[] _estimators; + public readonly IEstimator LastEstimator; + + private EstimatorChain(IEstimator[] estimators, TransformerScope[] scopes) + { + Contracts.AssertValueOrNull(estimators); + Contracts.AssertValueOrNull(scopes); + Contracts.Assert(Utils.Size(estimators) == Utils.Size(scopes)); + + _estimators = estimators ?? new IEstimator[0]; + _scopes = scopes ?? new TransformerScope[0]; + LastEstimator = estimators.LastOrDefault() as IEstimator; + + Contracts.Assert((_estimators.Length > 0) == (LastEstimator != null)); + } + + /// + /// Create an empty estimator chain. + /// + public EstimatorChain() + { + _estimators = new IEstimator[0]; + _scopes = new TransformerScope[0]; + LastEstimator = null; + } + + public TransformerChain Fit(IDataView input) + { + // REVIEW: before fitting, run schema propagation. + // Currently, it throws. 
+ // GetOutputSchema(SchemaShape.Create(input.Schema); + + IDataView current = input; + var xfs = new ITransformer[_estimators.Length]; + for (int i = 0; i < _estimators.Length; i++) + { + var est = _estimators[i]; + xfs[i] = est.Fit(current); + current = xfs[i].Transform(current); + } + + return new TransformerChain(xfs, _scopes); + } + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + var s = inputSchema; + foreach (var est in _estimators) + s = est.GetOutputSchema(s); + return s; + } + + public EstimatorChain Append(IEstimator estimator, TransformerScope scope = TransformerScope.Everything) + where TNewTrans : class, ITransformer + { + Contracts.CheckValue(estimator, nameof(estimator)); + return new EstimatorChain(_estimators.Append(estimator).ToArray(), _scopes.Append(scope).ToArray()); + } + } +} diff --git a/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs b/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs new file mode 100644 index 0000000000..d862c69028 --- /dev/null +++ b/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs @@ -0,0 +1,65 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Core.Data; + +namespace Microsoft.ML.Runtime.Data +{ + /// + /// Extension methods that allow chaining estimator and transformer pipes together. + /// + public static class LearningPipelineExtensions + { + /// + /// Create a composite reader estimator by appending an estimator to a reader estimator. + /// + public static CompositeReaderEstimator Append( + this IDataReaderEstimator> start, IEstimator estimator) + where TTrans : class, ITransformer + { + Contracts.CheckValue(start, nameof(start)); + Contracts.CheckValue(estimator, nameof(estimator)); + + return new CompositeReaderEstimator(start).Append(estimator); + } + + /// + /// Create an estimator chain by appending an estimator to an estimator. + /// + public static EstimatorChain Append( + this IEstimator start, IEstimator estimator, + TransformerScope scope = TransformerScope.Everything) + where TTrans : class, ITransformer + { + Contracts.CheckValue(start, nameof(start)); + Contracts.CheckValue(estimator, nameof(estimator)); + + return new EstimatorChain().Append(start).Append(estimator, scope); + } + + /// + /// Create a composite reader by appending a transformer to a data reader. + /// + public static CompositeDataReader Append(this IDataReader reader, TTrans transformer) + where TTrans : class, ITransformer + { + Contracts.CheckValue(reader, nameof(reader)); + Contracts.CheckValue(transformer, nameof(transformer)); + + return new CompositeDataReader(reader).AppendTransformer(transformer); + } + + /// + /// Create a transformer chain by appending a transformer to a transformer. + /// + public static TransformerChain Append(this ITransformer start, TTrans transformer) + where TTrans : class, ITransformer + { + Contracts.CheckValue(start, nameof(start)); + Contracts.CheckValue(transformer, nameof(transformer)); + + return new TransformerChain(start, transformer); + } + } +} diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs new file mode 100644 index 0000000000..d5f246689a --- /dev/null +++ b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs @@ -0,0 +1,227 @@ +// Licensed to the .NET Foundation under one or more agreements. 
+// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Core.Data; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Internal.Utilities; +using Microsoft.ML.Runtime.Model; +using System; +using System.Collections; +using System.Collections.Generic; +using System.IO; +using System.Linq; + +[assembly: LoadableClass(typeof(TransformerChain), typeof(TransformerChain), null, typeof(SignatureLoadModel), + "Transformer chain", TransformerChain.LoaderSignature)] + +namespace Microsoft.ML.Runtime.Data +{ + /// + /// This enum allows for 'tagging' the estimators (and subsequently transformers) in the chain to be used + /// 'only for training', 'for training and evaluation' etc. + /// Most notable example is, transformations over the label column should not be used for scoring, so the scope + /// should be or . + /// + [Flags] + public enum TransformerScope + { + None = 0, + Training = 1 << 0, + Testing = 1 << 1, + Scoring = 1 << 2, + TrainTest = Training | Testing, + Everything = Training | Testing | Scoring + } + + /// + /// A chain of transformers (possibly empty) that end with a . + /// For an empty chain, is always . + /// + public sealed class TransformerChain : ITransformer, ICanSaveModel, IEnumerable + where TLastTransformer : class, ITransformer + { + private readonly ITransformer[] _transformers; + private readonly TransformerScope[] _scopes; + public readonly TLastTransformer LastTransformer; + + private const string TransformDirTemplate = "Transform_{0:000}"; + + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "XF CHAIN", + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: TransformerChain.LoaderSignature); + } + + /// + /// Create a transformer chain by specifying transformers and their scopes. + /// + /// Transformers to be chained. + /// Transformer scopes, parallel to . + public TransformerChain(IEnumerable transformers, IEnumerable scopes) + { + Contracts.CheckValueOrNull(transformers); + Contracts.CheckValueOrNull(scopes); + + _transformers = transformers?.ToArray() ?? new ITransformer[0]; + _scopes = scopes?.ToArray() ?? new TransformerScope[0]; + LastTransformer = transformers.LastOrDefault() as TLastTransformer; + + Contracts.Check((_transformers.Length > 0) == (LastTransformer != null)); + Contracts.Check(_transformers.Length == _scopes.Length); + } + + /// + /// Create a transformer chain by specifying all the transformers. The scopes are assumed to be + /// . 
+ /// + /// + public TransformerChain(params ITransformer[] transformers) + { + Contracts.CheckValueOrNull(transformers); + + if (Utils.Size(transformers) == 0) + { + _transformers = new ITransformer[0]; + _scopes = new TransformerScope[0]; + LastTransformer = null; + } + else + { + _transformers = transformers.ToArray(); + _scopes = transformers.Select(x => TransformerScope.Everything).ToArray(); + LastTransformer = transformers.Last() as TLastTransformer; + Contracts.Check(LastTransformer != null); + } + } + + public ISchema GetOutputSchema(ISchema inputSchema) + { + Contracts.CheckValue(inputSchema, nameof(inputSchema)); + + var s = inputSchema; + foreach (var xf in _transformers) + s = xf.GetOutputSchema(s); + return s; + } + + public IDataView Transform(IDataView input) + { + Contracts.CheckValue(input, nameof(input)); + + // Trigger schema propagation prior to transforming. + // REVIEW: does this actually constitute 'early warning', given that Transform call is lazy anyway? + GetOutputSchema(input.Schema); + + var dv = input; + foreach (var xf in _transformers) + dv = xf.Transform(dv); + return dv; + } + + public TransformerChain GetModelFor(TransformerScope scopeFilter) + { + var xfs = new List(); + var scopes = new List(); + for (int i = 0; i < _transformers.Length; i++) + { + if ((_scopes[i] & scopeFilter) != TransformerScope.None) + { + xfs.Add(_transformers[i]); + scopes.Add(_scopes[i]); + } + } + return new TransformerChain(xfs.ToArray(), scopes.ToArray()); + } + + public TransformerChain Append(TNewLast transformer, TransformerScope scope = TransformerScope.Everything) + where TNewLast : class, ITransformer + { + Contracts.CheckValue(transformer, nameof(transformer)); + return new TransformerChain(_transformers.Append(transformer).ToArray(), _scopes.Append(scope).ToArray()); + } + + public void Save(ModelSaveContext ctx) + { + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); + + ctx.Writer.Write(_transformers.Length); + + for (int i = 0; i < _transformers.Length; i++) + { + ctx.Writer.Write((int)_scopes[i]); + var dirName = string.Format(TransformDirTemplate, i); + ctx.SaveModel(_transformers[i], dirName); + } + } + + /// + /// The loading constructor of transformer chain. Reverse of . + /// + internal TransformerChain(IHostEnvironment env, ModelLoadContext ctx) + { + int len = ctx.Reader.ReadInt32(); + _transformers = new ITransformer[len]; + _scopes = new TransformerScope[len]; + for (int i = 0; i < len; i++) + { + _scopes[i] = (TransformerScope)(ctx.Reader.ReadInt32()); + var dirName = string.Format(TransformDirTemplate, i); + ctx.LoadModel(env, out _transformers[i], dirName); + } + if (len > 0) + LastTransformer = _transformers[len - 1] as TLastTransformer; + else + LastTransformer = null; + } + + public void SaveTo(IHostEnvironment env, Stream outputStream) + { + using (var ch = env.Start("Saving pipeline")) + { + using (var rep = RepositoryWriter.CreateNew(outputStream, ch)) + { + ch.Trace("Saving transformer chain"); + ModelSaveContext.SaveModel(rep, this, TransformerChain.LoaderSignature); + rep.Commit(); + } + } + } + + public IEnumerator GetEnumerator() => ((IEnumerable)_transformers).GetEnumerator(); + + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); + } + + /// + /// Saving/loading routines for transformer chains. 
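// --- Illustrative sketch (not part of the diff above) ---
// Round-tripping a transformer through a stream with the SaveTo extension and LoadFrom
// defined just below. 'env' and 'someTransformer' are hypothetical; System.IO.MemoryStream
// is used only to keep the example self-contained, and the loaded chain is returned through
// its ITransformer interface.
private static ITransformer SaveAndReload(IHostEnvironment env, ITransformer someTransformer)
{
    byte[] bytes;
    using (var stream = new MemoryStream())
    {
        // Wraps the transformer into a single-element chain and writes it as a model file.
        someTransformer.SaveTo(env, stream);
        bytes = stream.ToArray();
    }

    // Read the chain back from a fresh stream over the saved bytes.
    using (var stream = new MemoryStream(bytes))
        return TransformerChain.LoadFrom(env, stream);
}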
+ /// + public static class TransformerChain + { + public const string LoaderSignature = "TransformerChain"; + + public static TransformerChain Create(IHostEnvironment env, ModelLoadContext ctx) + => new TransformerChain(env, ctx); + + /// + /// Save any transformer to a stream by wrapping it into a transformer chain. + /// + public static void SaveTo(this ITransformer transformer, IHostEnvironment env, Stream outputStream) + => new TransformerChain(transformer).SaveTo(env, outputStream); + + public static TransformerChain LoadFrom(IHostEnvironment env, Stream stream) + { + using (var rep = RepositoryReader.Open(stream, env)) + { + ModelLoadContext.LoadModel, SignatureLoadModel>(env, out var transformerChain, rep, LoaderSignature); + return transformerChain; + } + } + } +} diff --git a/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj b/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj index 4810872906..be16385901 100644 --- a/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj +++ b/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj @@ -19,6 +19,7 @@ + diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj index 2d8013f545..d8d5cd3d5d 100644 --- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj +++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj @@ -1,4 +1,9 @@  + + + + + @@ -22,4 +27,8 @@ + + + + \ No newline at end of file diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/AspirationalExamples.cs b/test/Microsoft.ML.Tests/Scenarios/Api/AspirationalExamples.cs new file mode 100644 index 0000000000..9ba601ad82 --- /dev/null +++ b/test/Microsoft.ML.Tests/Scenarios/Api/AspirationalExamples.cs @@ -0,0 +1,64 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +// This file contains code examples that currently do not even compile. +// They serve as the reference point of the 'desired user-facing API' for the future work. + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.ML.Tests.Scenarios.Api +{ + public class AspirationalExamples + { + public class IrisPrediction + { + public string PredictedLabel; + } + + public class IrisExample + { + public float SepalWidth { get; set; } + public float SepalLength { get; set; } + public float PetalWidth { get; set; } + public float PetalLength { get; set; } + } + + public void FirstExperienceWithML() + { + // Load the data into the system. + string dataPath = "iris-data.txt"; + var data = TextReader.FitAndRead(env, dataPath, row => ( + Label: row.ReadString(0), + SepalWidth: row.ReadFloat(1), + SepalLength: row.ReadFloat(2), + PetalWidth: row.ReadFloat(3), + PetalLength: row.ReadFloat(4))); + + + var preprocess = data.Schema.MakeEstimator(row => ( + // Convert string label to key. + Label: row.Label.DictionarizeLabel(), + // Concatenate all features into a vector. + Features: row.SepalWidth.ConcatWith(row.SepalLength, row.PetalWidth, row.PetalLength))); + + var pipeline = preprocess + // Append the trainer to the training pipeline. + .AppendEstimator(row => row.Label.PredictWithSdca(row.Features)) + .AppendEstimator(row => row.PredictedLabel.KeyToValue()); + + // Train the model and make some predictions. 
+ var model = pipeline.Fit(data); + + IrisPrediction prediction = model.Predict(new IrisExample + { + SepalWidth = 3.3f, + SepalLength = 1.6f, + PetalWidth = 0.2f, + PetalLength = 5.1f + }); + } + } +} diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs new file mode 100644 index 0000000000..fb6c56674b --- /dev/null +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs @@ -0,0 +1,49 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Learners; +using Xunit; + +namespace Microsoft.ML.Tests.Scenarios.Api +{ + public partial class ApiScenariosTests + { + /// + /// Cross-validation: Have a mechanism to do cross validation, that is, you come up with + /// a data source (optionally with stratification column), come up with an instantiable transform + /// and trainer pipeline, and it will handle (1) splitting up the data, (2) training the separate + /// pipelines on in-fold data, (3) scoring on the out-fold data, (4) returning the set of + /// evaluations and optionally trained pipes. (People always want metrics out of xfold, + /// they sometimes want the actual models too.) + /// + [Fact] + void New_CrossValidation() + { + var dataPath = GetDataPath(SentimentDataPath); + var testDataPath = GetDataPath(SentimentTestPath); + + using (var env = new TlcEnvironment(seed: 1, conc: 1)) + { + + var data = new MyTextLoader(env, MakeSentimentTextLoaderArgs()) + .FitAndRead(new MultiFileSource(dataPath)); + // Pipeline. + var pipeline = new MyTextTransform(env, MakeSentimentTextTransformArgs()) + .Append(new MySdca(env, new LinearClassificationTrainer.Arguments + { + NumThreads = 1, + ConvergenceTolerance = 1f + }, "Features", "Label")); + + var cv = new MyCrossValidation.BinaryCrossValidator(env) + { + NumFolds = 2 + }; + + var cvResult = cv.CrossValidate(data, pipeline); + } + } + } +} diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs new file mode 100644 index 0000000000..eb8a6991c5 --- /dev/null +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs @@ -0,0 +1,50 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Learners; +using System.Linq; +using Xunit; + +namespace Microsoft.ML.Tests.Scenarios.Api +{ + public partial class ApiScenariosTests + { + /// + /// Decomposable train and predict: Train on Iris multiclass problem, which will require + /// a transform on labels. Be able to reconstitute the pipeline for a prediction only task, + /// which will essentially "drop" the transform over labels, while retaining the property + /// that the predicted label for this has a key-type, the probability outputs for the classes + /// have the class labels as slot names, etc. This should be do-able without ugly compromises like, + /// say, injecting a dummy label. 
+ /// + [Fact] + void New_DecomposableTrainAndPredict() + { + var dataPath = GetDataPath(IrisDataPath); + using (var env = new TlcEnvironment()) + { + var data = new MyTextLoader(env, MakeIrisTextLoaderArgs()) + .FitAndRead(new MultiFileSource(dataPath)); + + var pipeline = new MyConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") + .Append(new MyTermTransform(env, "Label"), TransformerScope.TrainTest) + .Append(new MySdcaMulticlass(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label")) + .Append(new MyKeyToValueTransform(env, "PredictedLabel")); + + var model = pipeline.Fit(data).GetModelFor(TransformerScope.Scoring); + var engine = new MyPredictionEngine(env, model); + + var testLoader = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath)); + var testData = testLoader.AsEnumerable(env, false); + foreach (var input in testData.Take(20)) + { + var prediction = engine.Predict(input); + Assert.True(prediction.PredictedLabel == input.Label); + } + } + } + } +} diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Evaluation.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Evaluation.cs new file mode 100644 index 0000000000..dcb7f515c2 --- /dev/null +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Evaluation.cs @@ -0,0 +1,42 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Learners; +using Xunit; + +namespace Microsoft.ML.Tests.Scenarios.Api +{ + public partial class ApiScenariosTests + { + /// + /// Evaluation: Similar to the simple train scenario, except instead of having some + /// predictive structure, be able to score another "test" data file, run the result + /// through an evaluator and get metrics like AUC, accuracy, PR curves, and whatnot. + /// Getting metrics out of this shoudl be as straightforward and unannoying as possible. + /// + [Fact] + public void New_Evaluation() + { + var dataPath = GetDataPath(SentimentDataPath); + var testDataPath = GetDataPath(SentimentTestPath); + + using (var env = new TlcEnvironment(seed: 1, conc: 1)) + { + // Pipeline. + var pipeline = new MyTextLoader(env, MakeSentimentTextLoaderArgs()) + .Append(new MyTextTransform(env, MakeSentimentTextTransformArgs())) + .Append(new MySdca(env, new LinearClassificationTrainer.Arguments { NumThreads = 1 }, "Features", "Label")); + + // Train. + var model = pipeline.Fit(new MultiFileSource(dataPath)); + + // Evaluate on the test set. + var dataEval = model.Read(new MultiFileSource(testDataPath)); + var evaluator = new MyBinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments() { }); + var metrics = evaluator.Evaluate(dataEval); + } + } + } +} diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs new file mode 100644 index 0000000000..53cfba2fbf --- /dev/null +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs @@ -0,0 +1,58 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Learners; +using System; +using System.Linq; +using Xunit; + +namespace Microsoft.ML.Tests.Scenarios.Api +{ + public partial class ApiScenariosTests + { + /// + /// Extensibility: We can't possibly write every conceivable transform and should not try. + /// It should somehow be possible for a user to inject custom code to, say, transform data. + /// This might have a much steeper learning curve than the other usages (which merely involve + /// usage of already established components), but should still be possible. + /// + [Fact] + void New_Extensibility() + { + var dataPath = GetDataPath(IrisDataPath); + using (var env = new TlcEnvironment()) + { + var data = new MyTextLoader(env, MakeIrisTextLoaderArgs()) + .FitAndRead(new MultiFileSource(dataPath)); + + Action action = (i, j) => + { + j.Label = i.Label; + j.PetalLength = i.SepalLength > 3 ? i.PetalLength : i.SepalLength; + j.PetalWidth = i.PetalWidth; + j.SepalLength = i.SepalLength; + j.SepalWidth = i.SepalWidth; + }; + var pipeline = new MyConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") + .Append(new MyLambdaTransform(env, action)) + .Append(new MyTermTransform(env, "Label"), TransformerScope.TrainTest) + .Append(new MySdcaMulticlass(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label")) + .Append(new MyKeyToValueTransform(env, "PredictedLabel")); + + var model = pipeline.Fit(data).GetModelFor(TransformerScope.Scoring); + var engine = new MyPredictionEngine(env, model); + + var testLoader = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath)); + var testData = testLoader.AsEnumerable(env, false); + foreach (var input in testData.Take(20)) + { + var prediction = engine.Predict(input); + Assert.True(prediction.PredictedLabel == input.Label); + } + } + } + } +} diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/FileBasedSavingOfData.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/FileBasedSavingOfData.cs new file mode 100644 index 0000000000..6e355f8db2 --- /dev/null +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/FileBasedSavingOfData.cs @@ -0,0 +1,47 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Data.IO; +using Microsoft.ML.Runtime.Learners; +using Xunit; + +namespace Microsoft.ML.Tests.Scenarios.Api +{ + public partial class ApiScenariosTests + { + /// + /// File-based saving of data: Come up with transform pipeline. Transform training and + /// test data, and save the featurized data to some file, using the .idv format. + /// Train and evaluate multiple models over that pre-featurized data. (Useful for + /// sweeping scenarios, where you are training many times on the same data, + /// and don't necessarily want to transform it every single time.) + /// + [Fact] + void New_FileBasedSavingOfData() + { + var dataPath = GetDataPath(SentimentDataPath); + var testDataPath = GetDataPath(SentimentTestPath); + + using (var env = new TlcEnvironment(seed: 1, conc: 1)) + { + // Pipeline. 
+ var pipeline = new MyTextLoader(env, MakeSentimentTextLoaderArgs()) + .Append(new MyTextTransform(env, MakeSentimentTextTransformArgs())); + + var trainData = pipeline.Fit(new MultiFileSource(dataPath)).Read(new MultiFileSource(dataPath)); + + using (var file = env.CreateOutputFile("i.idv")) + trainData.SaveAsBinary(env, file.CreateWriteStream()); + + var trainer = new MySdca(env, new LinearClassificationTrainer.Arguments { NumThreads = 1 }, "Features", "Label"); + var loadedTrainData = new BinaryLoader(env, new BinaryLoader.Arguments(), new MultiFileSource("i.idv")); + + // Train. + var model = trainer.Train(loadedTrainData); + DeleteOutputPath("i.idv"); + } + } + } +} diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs new file mode 100644 index 0000000000..ae8f36183d --- /dev/null +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs @@ -0,0 +1,55 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.FastTree; +using Microsoft.ML.Runtime.Internal.Calibration; +using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.TextAnalytics; +using System.Collections.Generic; +using Xunit; + +namespace Microsoft.ML.Tests.Scenarios.Api +{ + + public partial class ApiScenariosTests + { + /// + /// Introspective training: Models that produce outputs and are otherwise black boxes are of limited use; + /// it is also necessary often to understand at least to some degree what was learnt. To outline critical + /// scenarios that have come up multiple times: + /// *) When I train a linear model, I should be able to inspect coefficients. + /// *) The tree ensemble learners, I should be able to inspect the trees. + /// *) The LDA transform, I should be able to inspect the topics. + /// I view it as essential from a usability perspective that this be discoverable to someone without + /// having to read documentation.E.g.: if I have var lda = new LdaTransform().Fit(data)(I don't insist on that + /// exact signature, just giving the idea), then if I were to type lda. + /// In Visual Studio, one of the auto-complete targets should be something like GetTopics. + /// + + [Fact] + public void New_IntrospectiveTraining() + { + var dataPath = GetDataPath(SentimentDataPath); + var testDataPath = GetDataPath(SentimentTestPath); + + using (var env = new TlcEnvironment(seed: 1, conc: 1)) + { + var data = new MyTextLoader(env, MakeSentimentTextLoaderArgs()) + .FitAndRead(new MultiFileSource(dataPath)); + + var pipeline = new MyTextTransform(env, MakeSentimentTextTransformArgs()) + .Append(new MySdca(env, new LinearClassificationTrainer.Arguments { NumThreads = 1 }, "Features", "Label")); + + // Train. + var model = pipeline.Fit(data); + + // Get feature weights. 
+ VBuffer weights = default; + model.LastTransformer.InnerModel.GetFeatureWeights(ref weights); + } + } + } +} diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs new file mode 100644 index 0000000000..e6ed7469bb --- /dev/null +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs @@ -0,0 +1,40 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Learners; +using System.Linq; +using Xunit; + +namespace Microsoft.ML.Tests.Scenarios.Api +{ + public partial class ApiScenariosTests + { + /// + /// Meta-components: Meta-components (e.g., components that themselves instantiate components) should not be booby-trapped. + /// When specifying what trainer OVA should use, a user will be able to specify any binary classifier. + /// If they specify a regression or multi-class classifier ideally that should be a compile error. + /// + [Fact] + public void New_Metacomponents() + { + var dataPath = GetDataPath(IrisDataPath); + using (var env = new TlcEnvironment()) + { + var data = new MyTextLoader(env, MakeIrisTextLoaderArgs()) + .FitAndRead(new MultiFileSource(dataPath)); + + var sdcaTrainer = new MySdca(env, new LinearClassificationTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label"); + var pipeline = new MyConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") + .Append(new MyTermTransform(env, "Label"), TransformerScope.TrainTest) + .Append(new MyOva(env, sdcaTrainer)) + .Append(new MyKeyToValueTransform(env, "PredictedLabel")); + + var model = pipeline.Fit(data); + } + } + } +} diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs new file mode 100644 index 0000000000..e8386d166e --- /dev/null +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs @@ -0,0 +1,57 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Learners; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.ML.Tests.Scenarios.Api +{ + public partial class ApiScenariosTests + { + /// + /// Multi-threaded prediction. A twist on "Simple train and predict", where we account that + /// multiple threads may want predictions at the same time. Because we deliberately do not + /// reallocate internal memory buffers on every single prediction, the PredictionEngine + /// (or its estimator/transformer based successor) is, like most stateful .NET objects, + /// fundamentally not thread safe. This is deliberate and as designed. However, some mechanism + /// to enable multi-threaded scenarios (e.g., a web server servicing requests) should be possible + /// and performant in the new API. 
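// --- Illustrative sketch (not part of the diff above) ---
// The test below serializes access to a single engine with a lock; an alternative mechanism,
// not part of this change, is to give each thread its own engine via System.Threading.ThreadLocal.
// The generic arguments of MyPredictionEngine and the 'env', 'transformer' and 'testData'
// parameters are assumptions based on the sentiment scenario used in the test below.
private static void PredictInParallel(IHostEnvironment env, ITransformer transformer, IEnumerable<SentimentData> testData)
{
    // One engine per thread, so no locking is needed around Predict.
    var engines = new ThreadLocal<MyPredictionEngine<SentimentData, SentimentPrediction>>(
        () => new MyPredictionEngine<SentimentData, SentimentPrediction>(env, transformer));

    Parallel.ForEach(testData, input =>
    {
        var prediction = engines.Value.Predict(input);
    });
}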
+ /// + [Fact] + void New_MultithreadedPrediction() + { + var dataPath = GetDataPath(SentimentDataPath); + var testDataPath = GetDataPath(SentimentTestPath); + + using (var env = new TlcEnvironment(seed: 1, conc: 1)) + { + // Pipeline. + var pipeline = new MyTextLoader(env, MakeSentimentTextLoaderArgs()) + .Append(new MyTextTransform(env, MakeSentimentTextTransformArgs())) + .Append(new MySdca(env, new LinearClassificationTrainer.Arguments { NumThreads = 1 }, "Features", "Label")); + + // Train. + var model = pipeline.Fit(new MultiFileSource(dataPath)); + + // Create prediction engine and test predictions. + var engine = new MyPredictionEngine(env, model.Transformer); + + // Take a couple examples out of the test data and run predictions on top. + var testData = model.Reader.Read(new MultiFileSource(GetDataPath(SentimentTestPath))) + .AsEnumerable(env, false); + + Parallel.ForEach(testData, (input) => + { + lock (engine) + { + var prediction = engine.Predict(input); + } + }); + } + } + } +} diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs new file mode 100644 index 0000000000..12a6a2db20 --- /dev/null +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs @@ -0,0 +1,52 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Models; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Learners; +using Xunit; + +namespace Microsoft.ML.Tests.Scenarios.Api +{ + public partial class ApiScenariosTests + { + /// + /// Reconfigurable predictions: The following should be possible: A user trains a binary classifier, + /// and through the test evaluator gets a PR curve, the based on the PR curve picks a new threshold + /// and configures the scorer (or more precisely instantiates a new scorer over the same predictor) + /// with some threshold derived from that. + /// + [Fact] + public void New_ReconfigurablePrediction() + { + var dataPath = GetDataPath(SentimentDataPath); + var testDataPath = GetDataPath(SentimentTestPath); + + using (var env = new TlcEnvironment(seed: 1, conc: 1)) + { + var dataReader = new MyTextLoader(env, MakeSentimentTextLoaderArgs()) + .Fit(new MultiFileSource(dataPath)); + + var data = dataReader.Read(new MultiFileSource(dataPath)); + var testData = dataReader.Read(new MultiFileSource(testDataPath)); + + // Pipeline. 
+ var pipeline = new MyTextTransform(env, MakeSentimentTextTransformArgs()) + .Fit(data); + + var trainer = new MySdca(env, new LinearClassificationTrainer.Arguments { NumThreads = 1 }, "Features", "Label"); + var trainData = pipeline.Transform(data); + var model = trainer.Fit(trainData); + + var scoredTest = model.Transform(pipeline.Transform(testData)); + var metrics = new MyBinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments()).Evaluate(scoredTest, "Label", "Probability"); + + var newModel = model.Clone(new BinaryClassifierScorer.Arguments { Threshold = 0.01f, ThresholdColumn = DefaultColumnNames.Probability }); + var newScoredTest = newModel.Transform(pipeline.Transform(testData)); + var newMetrics = new MyBinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments { Threshold = 0.01f, UseRawScoreThreshold = false }).Evaluate(newScoredTest, "Label", "Probability"); + } + + } + } +} diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs new file mode 100644 index 0000000000..76c12e6068 --- /dev/null +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs @@ -0,0 +1,53 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Learners; +using Xunit; +using System.Linq; + +namespace Microsoft.ML.Tests.Scenarios.Api +{ + public partial class ApiScenariosTests + { + /// + /// Start with a dataset in a text file. Run text featurization on text values. + /// Train a linear model over that. (I am thinking sentiment classification.) + /// Out of the result, produce some structure over which you can get predictions programmatically + /// (e.g., the prediction does not happen over a file as it did during training). + /// + [Fact] + public void New_SimpleTrainAndPredict() + { + var dataPath = GetDataPath(SentimentDataPath); + var testDataPath = GetDataPath(SentimentTestPath); + + using (var env = new TlcEnvironment(seed: 1, conc: 1)) + { + // Pipeline. + var pipeline = new MyTextLoader(env, MakeSentimentTextLoaderArgs()) + .Append(new MyTextTransform(env, MakeSentimentTextTransformArgs())) + .Append(new MySdca(env, new LinearClassificationTrainer.Arguments { NumThreads = 1 }, "Features", "Label")); + + // Train. + var model = pipeline.Fit(new MultiFileSource(dataPath)); + + // Create prediction engine and test predictions. + var engine = new MyPredictionEngine(env, model.Transformer); + + // Take a couple examples out of the test data and run predictions on top. + var testData = model.Reader.Read(new MultiFileSource(GetDataPath(SentimentTestPath))) + .AsEnumerable(env, false); + foreach (var input in testData.Take(5)) + { + var prediction = engine.Predict(input); + // Verify that predictions match and scores are separated from zero. 
+ Assert.Equal(input.Sentiment, prediction.Sentiment); + Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1); + } + } + } + } +} diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs new file mode 100644 index 0000000000..f5b5b98c90 --- /dev/null +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs @@ -0,0 +1,65 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Core.Data; +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Learners; +using System.Linq; +using Xunit; + +namespace Microsoft.ML.Tests.Scenarios.Api +{ + public partial class ApiScenariosTests + { + /// + /// Train, save/load model, predict: + /// Serve the scenario where training and prediction happen in different processes (or even different machines). + /// The actual test will not run in different processes, but will simulate the idea that the + /// "communication pipe" is just a serialized model of some form. + /// + [Fact] + public void New_TrainSaveModelAndPredict() + { + var dataPath = GetDataPath(SentimentDataPath); + var testDataPath = GetDataPath(SentimentTestPath); + + using (var env = new TlcEnvironment(seed: 1, conc: 1)) + { + // Pipeline. + var pipeline = new MyTextLoader(env, MakeSentimentTextLoaderArgs()) + .Append(new MyTextTransform(env, MakeSentimentTextTransformArgs())) + .Append(new MySdca(env, new LinearClassificationTrainer.Arguments { NumThreads = 1 }, "Features", "Label")); + + // Train. + var model = pipeline.Fit(new MultiFileSource(dataPath)); + + ITransformer loadedModel; + using (var file = env.CreateTempFile()) + { + // Save model. + using (var fs = file.CreateWriteStream()) + model.Transformer.SaveTo(env, fs); + + // Load model. + loadedModel = TransformerChain.LoadFrom(env, file.OpenReadStream()); + } + + // Create prediction engine and test predictions. + var engine = new MyPredictionEngine(env, loadedModel); + + // Take a couple examples out of the test data and run predictions on top. + var testData = model.Reader.Read(new MultiFileSource(GetDataPath(SentimentTestPath))) + .AsEnumerable(env, false); + foreach (var input in testData.Take(5)) + { + var prediction = engine.Predict(input); + // Verify that predictions match and scores are separated from zero. + Assert.Equal(input.Sentiment, prediction.Sentiment); + Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1); + } + } + } + } +} diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs new file mode 100644 index 0000000000..2e687c9138 --- /dev/null +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs @@ -0,0 +1,46 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Learners; +using Xunit; + +namespace Microsoft.ML.Tests.Scenarios.Api +{ + public partial class ApiScenariosTests + { + /// + /// Train with initial predictor: Similar to the simple train scenario, but also accept a pre-trained initial model. + /// The scenario might be one of the online linear learners that can take advantage of this, e.g., averaged perceptron. + /// + [Fact] + public void New_TrainWithInitialPredictor() + { + var dataPath = GetDataPath(SentimentDataPath); + + using (var env = new TlcEnvironment(seed: 1, conc: 1)) + { + // Pipeline. + var pipeline = new MyTextLoader(env, MakeSentimentTextLoaderArgs()) + .Append(new MyTextTransform(env, MakeSentimentTextTransformArgs())); + + // Train the pipeline, prepare train set. + var reader = pipeline.Fit(new MultiFileSource(dataPath)); + var trainData = reader.Read(new MultiFileSource(dataPath)); + + + // Train the first predictor. + var trainer = new MySdca(env, new LinearClassificationTrainer.Arguments + { + NumThreads = 1 + }, "Features", "Label"); + var firstModel = trainer.Fit(trainData); + + // Train the second predictor on the same data. + var secondTrainer = new MyAveragedPerceptron(env, new AveragedPerceptronTrainer.Arguments(), "Features", "Label"); + var finalModel = secondTrainer.Train(trainData, firstModel.InnerModel); + } + } + } +} diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs new file mode 100644 index 0000000000..25eebafc45 --- /dev/null +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs @@ -0,0 +1,39 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Data; +using Xunit; + +namespace Microsoft.ML.Tests.Scenarios.Api +{ + public partial class ApiScenariosTests + { + /// + /// Train with validation set: Similar to the simple train scenario, but also support a validation set. + /// The learner might be trees with early stopping. + /// + [Fact] + public void New_TrainWithValidationSet() + { + var dataPath = GetDataPath(SentimentDataPath); + var validationDataPath = GetDataPath(SentimentTestPath); + + using (var env = new TlcEnvironment(seed: 1, conc: 1)) + { + // Pipeline. + var pipeline = new MyTextLoader(env, MakeSentimentTextLoaderArgs()) + .Append(new MyTextTransform(env, MakeSentimentTextTransformArgs())); + + // Train the pipeline, prepare train and validation set. + var reader = pipeline.Fit(new MultiFileSource(dataPath)); + var trainData = reader.Read(new MultiFileSource(dataPath)); + var validData = reader.Read(new MultiFileSource(validationDataPath)); + + // Train model with validation set. + var trainer = new MySdca(env, new Runtime.Learners.LinearClassificationTrainer.Arguments(), "Features", "Label"); + var model = trainer.Train(trainData, validData); + } + } + } +} diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs new file mode 100644 index 0000000000..d0f79cc1f1 --- /dev/null +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs @@ -0,0 +1,59 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. 
+// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Data; +using Xunit; + +namespace Microsoft.ML.Tests.Scenarios.Api +{ + public partial class ApiScenariosTests + { + /// + /// Visibility: It should, possibly through the debugger, be not such a pain to actually + /// see what is happening to your data when you apply this or that transform. E.g.: if I + /// were to have the text "Help I'm a bug!" I should be able to see the steps where it is + /// normalized to "help i'm a bug" then tokenized into ["help", "i'm", "a", "bug"] then + /// mapped into term numbers [203, 25, 3, 511] then projected into the sparse + /// float vector {3:1, 25:1, 203:1, 511:1}, etc. etc. + /// + [Fact] + void New_Visibility() + { + var dataPath = GetDataPath(SentimentDataPath); + using (var env = new TlcEnvironment(seed: 1, conc: 1)) + { + var pipeline = new MyTextLoader(env, MakeSentimentTextLoaderArgs()) + .Append(new MyTextTransform(env, MakeSentimentTextTransformArgs())); + var data = pipeline.FitAndRead(new MultiFileSource(dataPath)); + // In order to find out available column names, you can go through schema and check + // column names and appropriate type for getter. + for (int i = 0; i < data.Schema.ColumnCount; i++) + { + var columnName = data.Schema.GetColumnName(i); + var columnType = data.Schema.GetColumnType(i).RawType; + } + + using (var cursor = data.GetRowCursor(x => true)) + { + Assert.True(cursor.Schema.TryGetColumnIndex("SentimentText", out int textColumn)); + Assert.True(cursor.Schema.TryGetColumnIndex("Features_TransformedText", out int transformedTextColumn)); + Assert.True(cursor.Schema.TryGetColumnIndex("Features", out int featureColumn)); + + var originalTextGettter = cursor.GetGetter(textColumn); + var transformedTextGettter = cursor.GetGetter>(transformedTextColumn); + var featureGettter = cursor.GetGetter>(featureColumn); + DvText text = default; + VBuffer transformedText = default; + VBuffer features = default; + while (cursor.MoveNext()) + { + originalTextGettter(ref text); + transformedTextGettter(ref transformedText); + featureGettter(ref features); + } + } + } + } + } +} diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs new file mode 100644 index 0000000000..1c2156b3e5 --- /dev/null +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs @@ -0,0 +1,685 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. 
+ +using Microsoft.ML.Core.Data; +using Microsoft.ML.Models; +using Microsoft.ML.Runtime; +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.CommandLine; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Data.IO; +using Microsoft.ML.Runtime.EntryPoints; +using Microsoft.ML.Runtime.Internal.Internallearn; +using Microsoft.ML.Runtime.Learners; +using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.Training; +using Microsoft.ML.Tests.Scenarios.Api; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; + +[assembly: LoadableClass(typeof(TransformWrapper), null, typeof(SignatureLoadModel), + "Transform wrapper", TransformWrapper.LoaderSignature)] +[assembly: LoadableClass(typeof(LoaderWrapper), null, typeof(SignatureLoadModel), + "Loader wrapper", LoaderWrapper.LoaderSignature)] + +namespace Microsoft.ML.Tests.Scenarios.Api +{ + using TScalarPredictor = IPredictorProducing; + using TWeightsPredictor = IPredictorWithFeatureWeights; + + public sealed class LoaderWrapper : IDataReader, ICanSaveModel + { + public const string LoaderSignature = "LoaderWrapper"; + + private readonly IHostEnvironment _env; + private readonly Func _loaderFactory; + + public LoaderWrapper(IHostEnvironment env, Func loaderFactory) + { + _env = env; + _loaderFactory = loaderFactory; + } + + public ISchema GetOutputSchema() + { + var emptyData = Read(new MultiFileSource(null)); + return emptyData.Schema; + } + + public IDataView Read(IMultiStreamSource input) => _loaderFactory(input); + + public void Save(ModelSaveContext ctx) + { + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); + var ldr = Read(new MultiFileSource(null)); + ctx.SaveModel(ldr, "Loader"); + } + + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "LDR WRPR", + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature); + } + + public LoaderWrapper(IHostEnvironment env, ModelLoadContext ctx) + { + ctx.CheckAtModel(GetVersionInfo()); + ctx.LoadModel(env, out var loader, "Loader", new MultiFileSource(null)); + + var loaderStream = new MemoryStream(); + using (var rep = RepositoryWriter.CreateNew(loaderStream)) + { + ModelSaveContext.SaveModel(rep, loader, "Loader"); + rep.Commit(); + } + + _env = env; + _loaderFactory = (IMultiStreamSource source) => + { + using (var rep = RepositoryReader.Open(loaderStream)) + { + ModelLoadContext.LoadModel(env, out var ldr, rep, "Loader", source); + return ldr; + } + }; + + } + } + + public class TransformWrapper : ITransformer, ICanSaveModel + { + public const string LoaderSignature = "TransformWrapper"; + private const string TransformDirTemplate = "Step_{0:000}"; + + protected readonly IHostEnvironment _env; + protected readonly IDataView _xf; + + public TransformWrapper(IHostEnvironment env, IDataView xf) + { + _env = env; + _xf = xf; + } + + public ISchema GetOutputSchema(ISchema inputSchema) + { + var dv = new EmptyDataView(_env, inputSchema); + var output = ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, dv); + return output.Schema; + } + + public void Save(ModelSaveContext ctx) + { + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); + + var dataPipe = _xf; + var transforms = new List(); + while (dataPipe is IDataTransform xf) + { + // REVIEW: a malicious user could construct a loop in the Source chain, that would + // cause this method to iterate forever (and throw something when the list overflows). 
There's + // no way to insulate from ALL malicious behavior. + transforms.Add(xf); + dataPipe = xf.Source; + Contracts.AssertValue(dataPipe); + } + transforms.Reverse(); + + ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_env, c, dataPipe.Schema)); + + ctx.Writer.Write(transforms.Count); + for (int i = 0; i < transforms.Count; i++) + { + var dirName = string.Format(TransformDirTemplate, i); + ctx.SaveModel(transforms[i], dirName); + } + } + + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "XF WRPR", + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature); + } + + public TransformWrapper(IHostEnvironment env, ModelLoadContext ctx) + { + ctx.CheckAtModel(GetVersionInfo()); + int n = ctx.Reader.ReadInt32(); + + ctx.LoadModel(env, out var loader, "Loader", new MultiFileSource(null)); + + IDataView data = loader; + for (int i = 0; i < n; i++) + { + var dirName = string.Format(TransformDirTemplate, i); + ctx.LoadModel(env, out var xf, dirName, data); + data = xf; + } + + _env = env; + _xf = data; + } + + public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, input); + } + + public interface IPredictorTransformer : ITransformer + { + TModel InnerModel { get; } + } + + public class ScorerWrapper : TransformWrapper, IPredictorTransformer + where TModel : IPredictor + { + protected readonly string _featureColumn; + + public ScorerWrapper(IHostEnvironment env, IDataView scorer, TModel trainedModel, string featureColumn) + : base(env, scorer) + { + _featureColumn = featureColumn; + InnerModel = trainedModel; + } + + public TModel InnerModel { get; } + } + + public class BinaryScorerWrapper : ScorerWrapper + where TModel : IPredictor + { + public BinaryScorerWrapper(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn, BinaryClassifierScorer.Arguments args) + : base(env, MakeScorer(env, inputSchema, featureColumn, model, args), model, featureColumn) + { + } + + private static IDataView MakeScorer(IHostEnvironment env, ISchema schema, string featureColumn, TModel model, BinaryClassifierScorer.Arguments args) + { + var settings = $"Binary{{{CmdParser.GetSettings(env, args, new BinaryClassifierScorer.Arguments())}}}"; + + var scorerFactorySettings = CmdParser.CreateComponentFactory( + typeof(IComponentFactory), + typeof(SignatureDataScorer), + settings); + + var bindable = ScoreUtils.GetSchemaBindableMapper(env, model, scorerFactorySettings: scorerFactorySettings); + var edv = new EmptyDataView(env, schema); + var data = new RoleMappedData(edv, "Label", featureColumn, opt: true); + + return new BinaryClassifierScorer(env, args, data.Data, bindable.Bind(env, data.Schema), data.Schema); + } + + public BinaryScorerWrapper Clone(BinaryClassifierScorer.Arguments scorerArgs) + { + var scorer = _xf as IDataScorerTransform; + return new BinaryScorerWrapper(_env, InnerModel, scorer.Source.Schema, _featureColumn, scorerArgs); + } + } + + public class MyTextLoader : IDataReaderEstimator + { + private readonly TextLoader.Arguments _args; + private readonly IHostEnvironment _env; + + public MyTextLoader(IHostEnvironment env, TextLoader.Arguments args) + { + _env = env; + _args = args; + } + + public LoaderWrapper Fit(IMultiStreamSource input) + { + return new LoaderWrapper(_env, x => new TextLoader(_env, _args, x)); + } + + public SchemaShape GetOutputSchema() + { + var emptyData = new TextLoader(_env, _args, 
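+                // NOTE: a MultiFileSource over a null path is empty, so this loader is created
+                // only to expose its output schema; no data is actually read here.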
new MultiFileSource(null)); + return SchemaShape.Create(emptyData.Schema); + } + } + + public interface ITrainerEstimator: IEstimator + where TTransformer: IPredictorTransformer + where TModel: IPredictor + { + TrainerInfo TrainerInfo { get; } + } + + public abstract class TrainerBase : ITrainerEstimator + where TTransformer : ScorerWrapper + where TModel : IPredictor + { + protected readonly IHostEnvironment _env; + protected readonly string _featureCol; + protected readonly string _labelCol; + + public TrainerInfo TrainerInfo { get; } + + protected TrainerBase(IHostEnvironment env, TrainerInfo trainerInfo, string featureColumn, string labelColumn) + { + _env = env; + _featureCol = featureColumn; + _labelCol = labelColumn; + TrainerInfo = trainerInfo; + } + + public TTransformer Fit(IDataView input) + { + return TrainTransformer(input); + } + + protected TTransformer TrainTransformer(IDataView trainSet, + IDataView validationSet = null, IPredictor initPredictor = null) + { + var cachedTrain = TrainerInfo.WantCaching ? new CacheDataView(_env, trainSet, prefetch: null) : trainSet; + + var trainRoles = new RoleMappedData(cachedTrain, label: _labelCol, feature: _featureCol); + var emptyData = new EmptyDataView(_env, trainSet.Schema); + IDataView normalizer = emptyData; + + if (TrainerInfo.NeedNormalization && trainRoles.Schema.FeaturesAreNormalized() == false) + { + var view = NormalizeTransform.CreateMinMaxNormalizer(_env, trainRoles.Data, name: trainRoles.Schema.Feature.Name); + normalizer = ApplyTransformUtils.ApplyAllTransformsToData(_env, view, emptyData, cachedTrain); + + trainRoles = new RoleMappedData(view, trainRoles.Schema.GetColumnRoleNames()); + } + + RoleMappedData validRoles; + + if (validationSet == null) + validRoles = null; + else + { + var cachedValid = TrainerInfo.WantCaching ? 
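+                    // The validation set gets the same treatment as the training set: optional
+                    // caching here, and the same normalizer (applied just below) so that train
+                    // and validation data are featurized identically.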
new CacheDataView(_env, validationSet, prefetch: null) : validationSet; + cachedValid = ApplyTransformUtils.ApplyAllTransformsToData(_env, normalizer, cachedValid); + validRoles = new RoleMappedData(cachedValid, label: _labelCol, feature: _featureCol); + } + + var pred = TrainCore(new TrainContext(trainRoles, validRoles, initPredictor)); + + var scoreRoles = new RoleMappedData(normalizer, label: _labelCol, feature: _featureCol); + return MakeScorer(pred, scoreRoles); + } + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + throw new NotImplementedException(); + } + + protected abstract TModel TrainCore(TrainContext trainContext); + + protected abstract TTransformer MakeScorer(TModel predictor, RoleMappedData data); + + protected ScorerWrapper MakeScorerBasic(TModel predictor, RoleMappedData data) + { + var scorer = ScoreUtils.GetScorer(predictor, data, _env, data.Schema); + return (TTransformer)(new ScorerWrapper(_env, scorer, predictor, data.Schema.Feature.Name)); + } + } + + public class MyTextTransform : IEstimator + { + private readonly IHostEnvironment _env; + private readonly TextTransform.Arguments _args; + + public MyTextTransform(IHostEnvironment env, TextTransform.Arguments args) + { + _env = env; + _args = args; + } + + public TransformWrapper Fit(IDataView input) + { + var xf = TextTransform.Create(_env, _args, input); + var empty = new EmptyDataView(_env, input.Schema); + var chunk = ApplyTransformUtils.ApplyAllTransformsToData(_env, xf, empty, input); + return new TransformWrapper(_env, chunk); + } + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + throw new NotImplementedException(); + } + } + + public class MyTermTransform : IEstimator + { + private readonly IHostEnvironment _env; + private readonly string _column; + private readonly string _srcColumn; + + public MyTermTransform(IHostEnvironment env, string column, string srcColumn = null) + { + _env = env; + _column = column; + _srcColumn = srcColumn; + } + + public TransformWrapper Fit(IDataView input) + { + var xf = new TermTransform(_env, input, _column, _srcColumn); + var empty = new EmptyDataView(_env, input.Schema); + var chunk = ApplyTransformUtils.ApplyAllTransformsToData(_env, xf, empty, input); + return new TransformWrapper(_env, chunk); + } + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + throw new NotImplementedException(); + } + } + + public class MyConcatTransform : IEstimator + { + private readonly IHostEnvironment _env; + private readonly string _name; + private readonly string[] _source; + + public MyConcatTransform(IHostEnvironment env, string name, params string[] source) + { + _env = env; + _name = name; + _source = source; + } + + public TransformWrapper Fit(IDataView input) + { + var xf = new ConcatTransform(_env, input, _name, _source); + var empty = new EmptyDataView(_env, input.Schema); + var chunk = ApplyTransformUtils.ApplyAllTransformsToData(_env, xf, empty, input); + return new TransformWrapper(_env, chunk); + } + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + throw new NotImplementedException(); + } + } + + public class MyKeyToValueTransform : IEstimator + { + private readonly IHostEnvironment _env; + private readonly string _name; + private readonly string _source; + + public MyKeyToValueTransform(IHostEnvironment env, string name, string source = null) + { + _env = env; + _name = name; + _source = source; + } + + public TransformWrapper Fit(IDataView input) + { + var xf = new KeyToValueTransform(_env, input, _name, 
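+                // (As in the other wrapper estimators above, the fitted transform is then
+                // re-applied over an EmptyDataView, so the returned TransformWrapper captures
+                // only the transform chain, not the training data it was fitted on.)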
_source); + var empty = new EmptyDataView(_env, input.Schema); + var chunk = ApplyTransformUtils.ApplyAllTransformsToData(_env, xf, empty, input); + return new TransformWrapper(_env, chunk); + } + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + throw new NotImplementedException(); + } + } + + public sealed class MySdca : TrainerBase, TWeightsPredictor> + { + private readonly LinearClassificationTrainer.Arguments _args; + + public MySdca(IHostEnvironment env, LinearClassificationTrainer.Arguments args, string featureCol, string labelCol) + : base(env, new TrainerInfo(), featureCol, labelCol) + { + _args = args; + } + + protected override TWeightsPredictor TrainCore(TrainContext context) => new LinearClassificationTrainer(_env, _args).Train(context); + + public ITransformer Train(IDataView trainData, IDataView validationData = null) => TrainTransformer(trainData, validationData); + + protected override BinaryScorerWrapper MakeScorer(TWeightsPredictor predictor, RoleMappedData data) + => new BinaryScorerWrapper(_env, predictor, data.Data.Schema, _featureCol, new BinaryClassifierScorer.Arguments()); + } + + public sealed class MySdcaMulticlass : TrainerBase, IPredictor> + { + private readonly SdcaMultiClassTrainer.Arguments _args; + + public MySdcaMulticlass(IHostEnvironment env, SdcaMultiClassTrainer.Arguments args, string featureCol, string labelCol) + : base(env, new TrainerInfo(), featureCol, labelCol) + { + _args = args; + } + + protected override ScorerWrapper MakeScorer(IPredictor predictor, RoleMappedData data) => MakeScorerBasic(predictor, data); + + protected override IPredictor TrainCore(TrainContext context) => new SdcaMultiClassTrainer(_env, _args).Train(context); + } + + public sealed class MyAveragedPerceptron : TrainerBase, IPredictor> + { + private readonly AveragedPerceptronTrainer _trainer; + + public MyAveragedPerceptron(IHostEnvironment env, AveragedPerceptronTrainer.Arguments args, string featureCol, string labelCol) + : base(env, new TrainerInfo(caching: false), featureCol, labelCol) + { + _trainer = new AveragedPerceptronTrainer(env, args); + } + + protected override IPredictor TrainCore(TrainContext trainContext) => _trainer.Train(trainContext); + + public ITransformer Train(IDataView trainData, IPredictor initialPredictor) + { + return TrainTransformer(trainData, initPredictor: initialPredictor); + } + + protected override BinaryScorerWrapper MakeScorer(IPredictor predictor, RoleMappedData data) + => new BinaryScorerWrapper(_env, predictor, data.Data.Schema, _featureCol, new BinaryClassifierScorer.Arguments()); + } + + public sealed class MyPredictionEngine + where TSrc : class + where TDst : class, new() + { + private readonly PredictionEngine _engine; + + public MyPredictionEngine(IHostEnvironment env, ITransformer pipe) + { + IDataView dv = env.CreateDataView(new TSrc[0]); + _engine = env.CreatePredictionEngine(pipe.Transform(dv)); + } + + public TDst Predict(TSrc example) + { + return _engine.Predict(example); + } + } + + public sealed class MyBinaryClassifierEvaluator + { + private readonly IHostEnvironment _env; + private readonly BinaryClassifierEvaluator _evaluator; + + public MyBinaryClassifierEvaluator(IHostEnvironment env, BinaryClassifierEvaluator.Arguments args) + { + _env = env; + _evaluator = new BinaryClassifierEvaluator(env, args); + } + + public BinaryClassificationMetrics Evaluate(IDataView data, string labelColumn = DefaultColumnNames.Label, + string probabilityColumn = DefaultColumnNames.Probability) + { + var ci = 
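+                // Locate the score column produced by the scorer; it is then mapped, together
+                // with the probability column, into the roles the BinaryClassifierEvaluator expects.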
EvaluateUtils.GetScoreColumnInfo(_env, data.Schema, null, DefaultColumnNames.Score, MetadataUtils.Const.ScoreColumnKind.BinaryClassification); + var map = new KeyValuePair[] + { + RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Probability, probabilityColumn), + RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, ci.Name) + }; + var rmd = new RoleMappedData(data, labelColumn, DefaultColumnNames.Features, opt: true, custom: map); + + var metricsDict = _evaluator.Evaluate(rmd); + return BinaryClassificationMetrics.FromMetrics(_env, metricsDict["OverallMetrics"], metricsDict["ConfusionMatrix"]).Single(); + } + } + + public static class MyCrossValidation + { + public sealed class BinaryCrossValidationMetrics + { + public readonly ITransformer[] FoldModels; + public readonly BinaryClassificationMetrics[] FoldMetrics; + + public BinaryCrossValidationMetrics(ITransformer[] models, BinaryClassificationMetrics[] metrics) + { + FoldModels = models; + FoldMetrics = metrics; + } + } + + public sealed class BinaryCrossValidator + { + private readonly IHostEnvironment _env; + + public int NumFolds { get; set; } = 2; + + public string StratificationColumn { get; set; } + + public string LabelColumn { get; set; } = DefaultColumnNames.Label; + + public BinaryCrossValidator(IHostEnvironment env) + { + _env = env; + } + + public BinaryCrossValidationMetrics CrossValidate(IDataView trainData, IEstimator estimator) + { + var models = new ITransformer[NumFolds]; + var metrics = new BinaryClassificationMetrics[NumFolds]; + + if (StratificationColumn == null) + { + StratificationColumn = "StratificationColumn"; + var random = new GenerateNumberTransform(_env, trainData, StratificationColumn); + trainData = random; + } + else + throw new NotImplementedException(); + + var evaluator = new MyBinaryClassifierEvaluator(_env, new BinaryClassifierEvaluator.Arguments() { }); + + for (int fold = 0; fold < NumFolds; fold++) + { + var trainFilter = new RangeFilter(_env, new RangeFilter.Arguments() + { + Column = StratificationColumn, + Min = (Double)fold / NumFolds, + Max = (Double)(fold + 1) / NumFolds, + Complement = true + }, trainData); + var testFilter = new RangeFilter(_env, new RangeFilter.Arguments() + { + Column = StratificationColumn, + Min = (Double)fold / NumFolds, + Max = (Double)(fold + 1) / NumFolds, + Complement = false + }, trainData); + + models[fold] = estimator.Fit(trainFilter); + var scoredTest = models[fold].Transform(testFilter); + metrics[fold] = evaluator.Evaluate(scoredTest, labelColumn: LabelColumn, probabilityColumn: "Probability"); + } + + return new BinaryCrossValidationMetrics(models, metrics); + + } + } + } + + public class MyLambdaTransform : IEstimator + where TSrc : class, new() + where TDst : class, new() + { + private readonly IHostEnvironment _env; + private readonly Action _action; + + public MyLambdaTransform(IHostEnvironment env, Action action) + { + _env = env; + _action = action; + } + + public TransformWrapper Fit(IDataView input) + { + var xf = LambdaTransform.CreateMap(_env, input, _action); + var empty = new EmptyDataView(_env, input.Schema); + var chunk = ApplyTransformUtils.ApplyAllTransformsToData(_env, xf, empty, input); + return new TransformWrapper(_env, chunk); + } + + public SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + throw new NotImplementedException(); + } + } + + public sealed class MyOva : TrainerBase, OvaPredictor> + { + private readonly ITrainerEstimator, TScalarPredictor> _binaryEstimator; + + public 
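+        // One-vs-all meta-trainer: TrainCore (below) relabels the training data once per class
+        // using LabelIndicatorTransform, fits the supplied binary estimator on each relabeled
+        // view, and combines the per-class predictors into a single OvaPredictor.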
MyOva(IHostEnvironment env, ITrainerEstimator, TScalarPredictor> estimator, + string featureColumn = DefaultColumnNames.Features, string labelColumn = DefaultColumnNames.Label) + : base(env, MakeTrainerInfo(estimator), featureColumn, labelColumn) + { + _binaryEstimator = estimator; + } + + private static TrainerInfo MakeTrainerInfo(ITrainerEstimator, TScalarPredictor> estimator) + => new TrainerInfo(estimator.TrainerInfo.NeedNormalization, estimator.TrainerInfo.NeedCalibration, false); + + protected override ScorerWrapper MakeScorer(OvaPredictor predictor, RoleMappedData data) + => MakeScorerBasic(predictor, data); + + protected override OvaPredictor TrainCore(TrainContext trainContext) + { + var trainRoles = trainContext.TrainingSet; + trainRoles.CheckMultiClassLabel(out var numClasses); + + var predictors = new IPredictorTransformer[numClasses]; + for (int iClass = 0; iClass < numClasses; iClass++) + { + var data = new LabelIndicatorTransform(_env, trainRoles.Data, iClass, "Label"); + predictors[iClass] = _binaryEstimator.Fit(data); + } + var prs = predictors.Select(x => x.InnerModel); + var finalPredictor = OvaPredictor.Create(_env.Register("ova"), prs.ToArray()); + return finalPredictor; + } + } + + public static class MyHelperExtensions + { + public static void SaveAsBinary(this IDataView data, IHostEnvironment env, Stream stream) + { + var saver = new BinarySaver(env, new BinarySaver.Arguments()); + using (var ch = env.Start("SaveData")) + DataSaverUtils.SaveDataView(ch, saver, data, stream); + } + + public static IDataView FitAndTransform(this IEstimator est, IDataView data) => est.Fit(data).Transform(data); + + public static IDataView FitAndRead(this IDataReaderEstimator> est, TSource source) + => est.Fit(source).Read(source); + } +} diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/IntrospectiveTraining.cs b/test/Microsoft.ML.Tests/Scenarios/Api/IntrospectiveTraining.cs index 00e4f9b703..b1be31e89d 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/IntrospectiveTraining.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/IntrospectiveTraining.cs @@ -38,7 +38,7 @@ private TOut GetValue(Dictionary keyValues, string key) /// [Fact] - void IntrospectiveTraining() + public void IntrospectiveTraining() { var dataPath = GetDataPath(SentimentDataPath); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs index 811f857046..d8e2fa8079 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs @@ -1,10 +1,12 @@ -using Microsoft.ML.Runtime; -using Microsoft.ML.Runtime.Api; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.EntryPoints; using Microsoft.ML.Runtime.FastTree; using Microsoft.ML.Runtime.Learners; -using System.Linq; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api @@ -17,7 +19,7 @@ public partial class ApiScenariosTests /// If they specify a regression or multi-class classifier ideally that should be a compile error. 
/// [Fact] - void Metacomponents() + public void Metacomponents() { var dataPath = GetDataPath(IrisDataPath); using (var env = new TlcEnvironment()) @@ -28,7 +30,7 @@ void Metacomponents() var trainer = new Ova(env, new Ova.Arguments { PredictorType = ComponentFactoryUtils.CreateFromFunction( - (e) => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments())) + e => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments())) }); IDataView trainData = trainer.Info.WantCaching ? (IDataView)new CacheDataView(env, concat, prefetch: null) : concat; @@ -36,21 +38,7 @@ void Metacomponents() // Auto-normalization. NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer); - var predictor = trainer.Train(new Runtime.TrainContext(trainRoles)); - - var scoreRoles = new RoleMappedData(concat, label: "Label", feature: "Features"); - IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema); - - var keyToValue = new KeyToValueTransform(env, scorer, "PredictedLabel"); - var model = env.CreatePredictionEngine(keyToValue); - - var testLoader = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath)); - var testData = testLoader.AsEnumerable(env, false); - foreach (var input in testData.Take(20)) - { - var prediction = model.Predict(input); - Assert.True(prediction.PredictedLabel == input.Label); - } + var predictor = trainer.Train(new TrainContext(trainRoles)); } } } diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/MultithreadedPrediction.cs b/test/Microsoft.ML.Tests/Scenarios/Api/MultithreadedPrediction.cs index 23fab1c1d6..b561f6b3d2 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/MultithreadedPrediction.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/MultithreadedPrediction.cs @@ -1,4 +1,8 @@ -using Microsoft.ML.Runtime.Api; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Api; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Learners; using System.Linq; diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Visibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Visibility.cs index 97bb01158e..d0c5e91787 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Visibility.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Visibility.cs @@ -1,4 +1,8 @@ -using Microsoft.ML.Runtime.Data; +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Data; using Xunit; namespace Microsoft.ML.Tests.Scenarios.Api