diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs
new file mode 100644
index 0000000000..2d5200ed5e
--- /dev/null
+++ b/src/Microsoft.ML.Core/Data/IEstimator.cs
@@ -0,0 +1,190 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+namespace Microsoft.ML.Core.Data
+{
+ ///
+ /// A set of 'requirements' to the incoming schema, as well as a set of 'promises' of the outgoing schema.
+ /// This is more relaxed than the proper <see cref="ISchema"/>, since it's only a subset of the columns,
+ /// and also since it doesn't specify exact <see cref="ColumnType"/>'s for vectors and keys.
+ ///
+ public sealed class SchemaShape
+ {
+ public readonly Column[] Columns;
+
+ public sealed class Column
+ {
+ public enum VectorKind
+ {
+ Scalar,
+ Vector,
+ VariableVector
+ }
+
+ public readonly string Name;
+ public readonly VectorKind Kind;
+ public readonly DataKind ItemKind;
+ public readonly bool IsKey;
+ public readonly string[] MetadataKinds;
+
+ public Column(string name, VectorKind vecKind, DataKind itemKind, bool isKey, string[] metadataKinds)
+ {
+ Contracts.CheckNonEmpty(name, nameof(name));
+ Contracts.CheckValue(metadataKinds, nameof(metadataKinds));
+
+ Name = name;
+ Kind = vecKind;
+ ItemKind = itemKind;
+ IsKey = isKey;
+ MetadataKinds = metadataKinds;
+ }
+ }
+
+ public SchemaShape(Column[] columns)
+ {
+ Contracts.CheckValue(columns, nameof(columns));
+ Columns = columns;
+ }
+
+ ///
+ /// Create a schema shape out of the fully defined schema.
+ ///
+ public static SchemaShape Create(ISchema schema)
+ {
+ Contracts.CheckValue(schema, nameof(schema));
+ var cols = new List<Column>();
+
+ for (int iCol = 0; iCol < schema.ColumnCount; iCol++)
+ {
+ if (!schema.IsHidden(iCol))
+ {
+ Column.VectorKind vecKind;
+ var type = schema.GetColumnType(iCol);
+ if (type.IsKnownSizeVector)
+ vecKind = Column.VectorKind.Vector;
+ else if (type.IsVector)
+ vecKind = Column.VectorKind.VariableVector;
+ else
+ vecKind = Column.VectorKind.Scalar;
+
+ var kind = type.ItemType.RawKind;
+ var isKey = type.ItemType.IsKey;
+
+ var metadataNames = schema.GetMetadataTypes(iCol)
+ .Select(kvp => kvp.Key)
+ .ToArray();
+ cols.Add(new Column(schema.GetColumnName(iCol), vecKind, kind, isKey, metadataNames));
+ }
+ }
+ return new SchemaShape(cols.ToArray());
+ }
+
+ ///
+ /// Returns the column with a specified <paramref name="name"/>, and null if there is no such column.
+ ///
+ public Column FindColumn(string name)
+ {
+ Contracts.CheckValue(name, nameof(name));
+ return Columns.FirstOrDefault(x => x.Name == name);
+ }
+
+ // REVIEW: I think we should have an IsCompatible method to check if it's OK to use one schema shape
+ // as an input to another schema shape. I started writing, but realized that there's more than one way to check for
+ // the 'compatibility': as in, 'CAN be compatible' vs. 'WILL be compatible'.
+ }
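+
+ // Illustrative usage sketch (editor's addition, not part of this change; the column names are hypothetical):
+ // a SchemaShape can be built by hand to describe an estimator's expected input, or derived from a real schema.
+ //
+ // var shape = new SchemaShape(new[]
+ // {
+ //     new SchemaShape.Column("Label", SchemaShape.Column.VectorKind.Scalar, DataKind.R4, isKey: false, new string[0]),
+ //     new SchemaShape.Column("Features", SchemaShape.Column.VectorKind.Vector, DataKind.R4, isKey: false, new string[0])
+ // });
+ // var labelColumn = shape.FindColumn("Label"); // null if no such column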
+
+ ///
+ /// Exception class for schema validation errors.
+ ///
+ public class SchemaException : Exception
+ {
+ }
+
+ ///
+ /// The 'data reader' takes a certain kind of input and turns it into an <see cref="IDataView"/>.
+ ///
+ /// The type of input the reader takes.
+ public interface IDataReader<in TSource>
+ {
+ ///
+ /// Produce the data view from the specified input.
+ /// Note that <see cref="IDataView"/>'s are lazy, so no actual reading happens here, just schema validation.
+ ///
+ IDataView Read(TSource input);
+
+ ///
+ /// The output schema of the reader.
+ ///
+ ISchema GetOutputSchema();
+ }
+
+ ///
+ /// Sometimes we need to 'fit' an <see cref="IDataReader{TSource}"/>.
+ /// A DataReader estimator is the object that does it.
+ ///
+ public interface IDataReaderEstimator<in TSource, out TReader>
+ where TReader : IDataReader<TSource>
+ {
+ ///
+ /// Train and return a data reader.
+ ///
+ /// REVIEW: you could consider the transformer to take a different <typeparamref name="TSource"/>, but we don't have such components
+ /// yet, so why complicate matters?
+ ///
+ TReader Fit(TSource input);
+
+ ///
+ /// The 'promise' of the output schema.
+ /// It will be used for schema propagation.
+ ///
+ SchemaShape GetOutputSchema();
+ }
+
+ ///
+ /// The transformer is a component that transforms data.
+ /// It also supports 'schema propagation' to answer the question of 'how will the data with this schema look after you transform it?'.
+ ///
+ public interface ITransformer
+ {
+ ///
+ /// Schema propagation for transformers.
+ /// Returns the output schema of the data, if the input schema is like the one provided.
+ /// Throws iff the input schema is not valid for the transformer.
+ ///
+ ISchema GetOutputSchema(ISchema inputSchema);
+
+ ///
+ /// Take the data in, make transformations, output the data.
+ /// Note that <see cref="IDataView"/>'s are lazy, so no actual transformations happen here, just schema validation.
+ ///
+ IDataView Transform(IDataView input);
+ }
+
+ ///
+ /// The estimator (in Spark terminology) is an 'untrained transformer'. It needs to 'fit' on the data to manufacture
+ /// a transformer.
+ /// It also provides the 'schema propagation' like transformers do, but over <see cref="SchemaShape"/> instead of <see cref="ISchema"/>.
+ ///
+ public interface IEstimator<out TTransformer>
+ where TTransformer : ITransformer
+ {
+ ///
+ /// Train and return a transformer.
+ ///
+ TTransformer Fit(IDataView input);
+
+ ///
+ /// Schema propagation for estimators.
+ /// Returns the output schema shape of the estimator, if the input schema shape is like the one provided.
+ /// Throws iff the input schema is not valid for the estimator.
+ ///
+ SchemaShape GetOutputSchema(SchemaShape inputSchema);
+ }
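+
+ // Illustrative usage sketch (editor's addition, not part of this change; 'myEstimator' and 'trainData' are hypothetical):
+ // an estimator is fitted to produce a transformer, which can then propagate schema and transform data.
+ //
+ // ITransformer transformer = myEstimator.Fit(trainData);
+ // ISchema outputSchema = transformer.GetOutputSchema(trainData.Schema);
+ // IDataView transformed = transformer.Transform(trainData);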
+}
diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs
new file mode 100644
index 0000000000..db4f207fa3
--- /dev/null
+++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs
@@ -0,0 +1,111 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Core.Data;
+using Microsoft.ML.Runtime.Model;
+using System.IO;
+
+namespace Microsoft.ML.Runtime.Data
+{
+ ///
+ /// This class represents a data reader that applies a transformer chain after reading.
+ /// It also has methods to save itself to a repository.
+ ///
+ public sealed class CompositeDataReader<TSource, TLastTransformer> : IDataReader<TSource>
+ where TLastTransformer : class, ITransformer
+ {
+ ///
+ /// The underlying data reader.
+ ///
+ public readonly IDataReader<TSource> Reader;
+ ///
+ /// The chain of transformers (possibly empty) that are applied to data upon reading.
+ ///
+ public readonly TransformerChain<TLastTransformer> Transformer;
+
+ public CompositeDataReader(IDataReader<TSource> reader, TransformerChain<TLastTransformer> transformerChain = null)
+ {
+ Contracts.CheckValue(reader, nameof(reader));
+ Contracts.CheckValueOrNull(transformerChain);
+
+ Reader = reader;
+ Transformer = transformerChain ?? new TransformerChain<TLastTransformer>();
+ }
+
+ public IDataView Read(TSource input)
+ {
+ var idv = Reader.Read(input);
+ idv = Transformer.Transform(idv);
+ return idv;
+ }
+
+ public ISchema GetOutputSchema()
+ {
+ var s = Reader.GetOutputSchema();
+ return Transformer.GetOutputSchema(s);
+ }
+
+ ///
+ /// Append a new transformer to the end.
+ ///
+ /// The new composite data reader
+ public CompositeDataReader<TSource, TNewLast> AppendTransformer<TNewLast>(TNewLast transformer)
+ where TNewLast : class, ITransformer
+ {
+ Contracts.CheckValue(transformer, nameof(transformer));
+
+ return new CompositeDataReader<TSource, TNewLast>(Reader, Transformer.Append(transformer));
+ }
+
+ ///
+ /// Save the contents to a stream, as a "model file".
+ ///
+ public void SaveTo(IHostEnvironment env, Stream outputStream)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(outputStream, nameof(outputStream));
+
+ env.Check(outputStream.CanWrite && outputStream.CanSeek, "Need a writable and seekable stream to save");
+ using (var ch = env.Start("Saving pipeline"))
+ {
+ using (var rep = RepositoryWriter.CreateNew(outputStream, ch))
+ {
+ ch.Trace("Saving data reader");
+ ModelSaveContext.SaveModel(rep, Reader, "Reader");
+
+ ch.Trace("Saving transformer chain");
+ ModelSaveContext.SaveModel(rep, Transformer, TransformerChain.LoaderSignature);
+ rep.Commit();
+ }
+ }
+ }
+ }
+
+ ///
+ /// Utility class to facilitate loading from a stream.
+ ///
+ public static class CompositeDataReader
+ {
+ ///
+ /// Load the pipeline from stream.
+ ///
+ public static CompositeDataReader<IMultiStreamSource, ITransformer> LoadFrom(IHostEnvironment env, Stream stream)
+ {
+ Contracts.CheckValue(env, nameof(env));
+ env.CheckValue(stream, nameof(stream));
+
+ env.Check(stream.CanRead && stream.CanSeek, "Need a readable and seekable stream to load");
+ using (var rep = RepositoryReader.Open(stream, env))
+ using (var ch = env.Start("Loading pipeline"))
+ {
+ ch.Trace("Loading data reader");
+ ModelLoadContext.LoadModel<IDataReader<IMultiStreamSource>, SignatureLoadModel>(env, out var reader, rep, "Reader");
+
+ ch.Trace("Loader transformer chain");
+ ModelLoadContext.LoadModel<TransformerChain<ITransformer>, SignatureLoadModel>(env, out var transformerChain, rep, TransformerChain.LoaderSignature);
+ return new CompositeDataReader<IMultiStreamSource, ITransformer>(reader, transformerChain);
+ }
+ }
+ }
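+
+ // Illustrative usage sketch (editor's addition, not part of this change; 'env' and 'compositeReader' are hypothetical):
+ // a fitted reader-plus-transformers pipeline can be saved to a stream and loaded back later.
+ //
+ // using (var fs = File.Create("pipeline.zip"))
+ //     compositeReader.SaveTo(env, fs);
+ //
+ // CompositeDataReader<IMultiStreamSource, ITransformer> loadedPipeline;
+ // using (var fs = File.OpenRead("pipeline.zip"))
+ //     loadedPipeline = CompositeDataReader.LoadFrom(env, fs);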
+}
diff --git a/src/Microsoft.ML.Data/DataLoadSave/CompositeReaderEstimator.cs b/src/Microsoft.ML.Data/DataLoadSave/CompositeReaderEstimator.cs
new file mode 100644
index 0000000000..49f7f8b99d
--- /dev/null
+++ b/src/Microsoft.ML.Data/DataLoadSave/CompositeReaderEstimator.cs
@@ -0,0 +1,59 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Core.Data;
+
+namespace Microsoft.ML.Runtime.Data
+{
+ ///
+ /// An estimator class for composite data reader.
+ /// It can be used to build a 'trainable smart data reader', although this pattern is not very common.
+ ///
+ public sealed class CompositeReaderEstimator<TSource, TLastTransformer> : IDataReaderEstimator<TSource, CompositeDataReader<TSource, TLastTransformer>>
+ where TLastTransformer : class, ITransformer
+ {
+ private readonly IDataReaderEstimator<TSource, IDataReader<TSource>> _start;
+ private readonly EstimatorChain<TLastTransformer> _estimatorChain;
+
+ public CompositeReaderEstimator(IDataReaderEstimator<TSource, IDataReader<TSource>> start, EstimatorChain<TLastTransformer> estimatorChain = null)
+ {
+ Contracts.CheckValue(start, nameof(start));
+ Contracts.CheckValueOrNull(estimatorChain);
+
+ _start = start;
+ _estimatorChain = estimatorChain ?? new EstimatorChain<TLastTransformer>();
+
+ // REVIEW: enforce that estimator chain can read the reader's schema.
+ // Right now it throws.
+ // GetOutputSchema();
+ }
+
+ public CompositeDataReader<TSource, TLastTransformer> Fit(TSource input)
+ {
+ var start = _start.Fit(input);
+ var idv = start.Read(input);
+
+ var xfChain = _estimatorChain.Fit(idv);
+ return new CompositeDataReader<TSource, TLastTransformer>(start, xfChain);
+ }
+
+ public SchemaShape GetOutputSchema()
+ {
+ var shape = _start.GetOutputSchema();
+ return _estimatorChain.GetOutputSchema(shape);
+ }
+
+ ///
+ /// Append another estimator to the end.
+ ///
+ public CompositeReaderEstimator<TSource, TNewTrans> Append<TNewTrans>(IEstimator<TNewTrans> estimator)
+ where TNewTrans : class, ITransformer
+ {
+ Contracts.CheckValue(estimator, nameof(estimator));
+
+ return new CompositeReaderEstimator<TSource, TNewTrans>(_start, _estimatorChain.Append(estimator));
+ }
+ }
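+
+ // Illustrative usage sketch (editor's addition, not part of this change; 'loaderEstimator', 'featurizerEstimator'
+ // and 'files' are hypothetical): fitting the composite estimator fits the reader and the chained estimators in
+ // sequence, and returns a reader that applies the fitted transformers on every Read.
+ //
+ // var readerEstimator = loaderEstimator.Append(featurizerEstimator);
+ // var reader = readerEstimator.Fit(files);
+ // IDataView data = reader.Read(files);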
+
+}
diff --git a/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs b/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs
new file mode 100644
index 0000000000..727dfa9e0b
--- /dev/null
+++ b/src/Microsoft.ML.Data/DataLoadSave/EstimatorChain.cs
@@ -0,0 +1,78 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Core.Data;
+using Microsoft.ML.Runtime.Internal.Utilities;
+using System.Linq;
+
+namespace Microsoft.ML.Runtime.Data
+{
+ ///
+ /// Represents a chain (potentially empty) of estimators that ends with a <typeparamref name="TLastTransformer"/>.
+ /// If the chain is empty, <see cref="LastEstimator"/> is always <see langword="null"/>.
+ ///
+ public sealed class EstimatorChain<TLastTransformer> : IEstimator<TransformerChain<TLastTransformer>>
+ where TLastTransformer : class, ITransformer
+ {
+ private readonly TransformerScope[] _scopes;
+ private readonly IEstimator<ITransformer>[] _estimators;
+ public readonly IEstimator<TLastTransformer> LastEstimator;
+
+ private EstimatorChain(IEstimator<ITransformer>[] estimators, TransformerScope[] scopes)
+ {
+ Contracts.AssertValueOrNull(estimators);
+ Contracts.AssertValueOrNull(scopes);
+ Contracts.Assert(Utils.Size(estimators) == Utils.Size(scopes));
+
+ _estimators = estimators ?? new IEstimator<ITransformer>[0];
+ _scopes = scopes ?? new TransformerScope[0];
+ LastEstimator = _estimators.LastOrDefault() as IEstimator<TLastTransformer>;
+
+ Contracts.Assert((_estimators.Length > 0) == (LastEstimator != null));
+ }
+
+ ///
+ /// Create an empty estimator chain.
+ ///
+ public EstimatorChain()
+ {
+ _estimators = new IEstimator<ITransformer>[0];
+ _scopes = new TransformerScope[0];
+ LastEstimator = null;
+ }
+
+ public TransformerChain<TLastTransformer> Fit(IDataView input)
+ {
+ // REVIEW: before fitting, run schema propagation.
+ // Currently, it throws.
+ // GetOutputSchema(SchemaShape.Create(input.Schema));
+
+ IDataView current = input;
+ var xfs = new ITransformer[_estimators.Length];
+ for (int i = 0; i < _estimators.Length; i++)
+ {
+ var est = _estimators[i];
+ xfs[i] = est.Fit(current);
+ current = xfs[i].Transform(current);
+ }
+
+ return new TransformerChain<TLastTransformer>(xfs, _scopes);
+ }
+
+ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ var s = inputSchema;
+ foreach (var est in _estimators)
+ s = est.GetOutputSchema(s);
+ return s;
+ }
+
+ public EstimatorChain<TNewTrans> Append<TNewTrans>(IEstimator<TNewTrans> estimator, TransformerScope scope = TransformerScope.Everything)
+ where TNewTrans : class, ITransformer
+ {
+ Contracts.CheckValue(estimator, nameof(estimator));
+ return new EstimatorChain<TNewTrans>(_estimators.Append(estimator).ToArray(), _scopes.Append(scope).ToArray());
+ }
+ }
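+
+ // Illustrative usage sketch (editor's addition, not part of this change; 'concatEstimator', 'trainerEstimator'
+ // and 'trainData' are hypothetical): estimators are fitted left to right, each one seeing the data as transformed
+ // by its predecessors, and Fit returns the corresponding TransformerChain.
+ //
+ // var pipeline = new EstimatorChain<ITransformer>()
+ //     .Append(concatEstimator)
+ //     .Append(trainerEstimator);
+ // var model = pipeline.Fit(trainData);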
+}
diff --git a/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs b/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs
new file mode 100644
index 0000000000..d862c69028
--- /dev/null
+++ b/src/Microsoft.ML.Data/DataLoadSave/EstimatorExtensions.cs
@@ -0,0 +1,65 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Core.Data;
+
+namespace Microsoft.ML.Runtime.Data
+{
+ ///
+ /// Extension methods that allow chaining estimator and transformer pipes together.
+ ///
+ public static class LearningPipelineExtensions
+ {
+ ///
+ /// Create a composite reader estimator by appending an estimator to a reader estimator.
+ ///
+ public static CompositeReaderEstimator<TSource, TTrans> Append<TSource, TTrans>(
+ this IDataReaderEstimator<TSource, IDataReader<TSource>> start, IEstimator<TTrans> estimator)
+ where TTrans : class, ITransformer
+ {
+ Contracts.CheckValue(start, nameof(start));
+ Contracts.CheckValue(estimator, nameof(estimator));
+
+ return new CompositeReaderEstimator<TSource, ITransformer>(start).Append(estimator);
+ }
+
+ ///
+ /// Create an estimator chain by appending an estimator to an estimator.
+ ///
+ public static EstimatorChain<TTrans> Append<TTrans>(
+ this IEstimator<ITransformer> start, IEstimator<TTrans> estimator,
+ TransformerScope scope = TransformerScope.Everything)
+ where TTrans : class, ITransformer
+ {
+ Contracts.CheckValue(start, nameof(start));
+ Contracts.CheckValue(estimator, nameof(estimator));
+
+ return new EstimatorChain<ITransformer>().Append(start).Append(estimator, scope);
+ }
+
+ ///
+ /// Create a composite reader by appending a transformer to a data reader.
+ ///
+ public static CompositeDataReader<TSource, TTrans> Append<TSource, TTrans>(this IDataReader<TSource> reader, TTrans transformer)
+ where TTrans : class, ITransformer
+ {
+ Contracts.CheckValue(reader, nameof(reader));
+ Contracts.CheckValue(transformer, nameof(transformer));
+
+ return new CompositeDataReader<TSource, ITransformer>(reader).AppendTransformer(transformer);
+ }
+
+ ///
+ /// Create a transformer chain by appending a transformer to a transformer.
+ ///
+ public static TransformerChain<TTrans> Append<TTrans>(this ITransformer start, TTrans transformer)
+ where TTrans : class, ITransformer
+ {
+ Contracts.CheckValue(start, nameof(start));
+ Contracts.CheckValue(transformer, nameof(transformer));
+
+ return new TransformerChain<TTrans>(start, transformer);
+ }
+ }
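+
+ // Illustrative usage sketch (editor's addition, not part of this change; all names below are hypothetical):
+ // the four Append overloads cover reader estimator + estimator, estimator + estimator,
+ // reader + transformer, and transformer + transformer.
+ //
+ // var readerEstimator = loaderEstimator.Append(featurizerEstimator);  // CompositeReaderEstimator
+ // var estimatorChain = featurizerEstimator.Append(trainerEstimator);  // EstimatorChain
+ // var compositeReader = loader.Append(fittedFeaturizer);              // CompositeDataReader
+ // var transformerChain = fittedFeaturizer.Append(fittedModel);        // TransformerChain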
+}
diff --git a/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs
new file mode 100644
index 0000000000..d5f246689a
--- /dev/null
+++ b/src/Microsoft.ML.Data/DataLoadSave/TransformerChain.cs
@@ -0,0 +1,227 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Core.Data;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Internal.Utilities;
+using Microsoft.ML.Runtime.Model;
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+
+[assembly: LoadableClass(typeof(TransformerChain<ITransformer>), typeof(TransformerChain), null, typeof(SignatureLoadModel),
+ "Transformer chain", TransformerChain.LoaderSignature)]
+
+namespace Microsoft.ML.Runtime.Data
+{
+ ///
+ /// This enum allows for 'tagging' the estimators (and subsequently transformers) in the chain to be used
+ /// 'only for training', 'for training and evaluation' etc.
+ /// The most notable example is that transformations over the label column should not be used for scoring, so the scope
+ /// should be <see cref="TransformerScope.Training"/> or <see cref="TransformerScope.TrainTest"/>.
+ ///
+ [Flags]
+ public enum TransformerScope
+ {
+ None = 0,
+ Training = 1 << 0,
+ Testing = 1 << 1,
+ Scoring = 1 << 2,
+ TrainTest = Training | Testing,
+ Everything = Training | Testing | Scoring
+ }
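+
+ // Illustrative usage sketch (editor's addition, not part of this change; 'featurizerEstimator',
+ // 'labelToKeyEstimator', 'trainerEstimator' and 'trainData' are hypothetical): tagging an estimator with a scope
+ // lets the fitted chain be filtered later, e.g. dropping label transformations for scoring.
+ //
+ // var pipeline = featurizerEstimator
+ //     .Append(labelToKeyEstimator, TransformerScope.TrainTest)
+ //     .Append(trainerEstimator);
+ // var scoringModel = pipeline.Fit(trainData).GetModelFor(TransformerScope.Scoring);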
+
+ ///
+ /// A chain of transformers (possibly empty) that ends with a <typeparamref name="TLastTransformer"/>.
+ /// For an empty chain, <see cref="LastTransformer"/> is always <see langword="null"/>.
+ ///
+ public sealed class TransformerChain<TLastTransformer> : ITransformer, ICanSaveModel, IEnumerable<ITransformer>
+ where TLastTransformer : class, ITransformer
+ {
+ private readonly ITransformer[] _transformers;
+ private readonly TransformerScope[] _scopes;
+ public readonly TLastTransformer LastTransformer;
+
+ private const string TransformDirTemplate = "Transform_{0:000}";
+
+ private static VersionInfo GetVersionInfo()
+ {
+ return new VersionInfo(
+ modelSignature: "XF CHAIN",
+ verWrittenCur: 0x00010001, // Initial
+ verReadableCur: 0x00010001,
+ verWeCanReadBack: 0x00010001,
+ loaderSignature: TransformerChain.LoaderSignature);
+ }
+
+ ///
+ /// Create a transformer chain by specifying transformers and their scopes.
+ ///
+ /// Transformers to be chained.
+ /// Transformer scopes, parallel to <paramref name="transformers"/>.
+ public TransformerChain(IEnumerable transformers, IEnumerable scopes)
+ {
+ Contracts.CheckValueOrNull(transformers);
+ Contracts.CheckValueOrNull(scopes);
+
+ _transformers = transformers?.ToArray() ?? new ITransformer[0];
+ _scopes = scopes?.ToArray() ?? new TransformerScope[0];
+ LastTransformer = _transformers.LastOrDefault() as TLastTransformer;
+
+ Contracts.Check((_transformers.Length > 0) == (LastTransformer != null));
+ Contracts.Check(_transformers.Length == _scopes.Length);
+ }
+
+ ///
+ /// Create a transformer chain by specifying all the transformers. The scopes are assumed to be
+ /// <see cref="TransformerScope.Everything"/>.
+ ///
+ ///
+ public TransformerChain(params ITransformer[] transformers)
+ {
+ Contracts.CheckValueOrNull(transformers);
+
+ if (Utils.Size(transformers) == 0)
+ {
+ _transformers = new ITransformer[0];
+ _scopes = new TransformerScope[0];
+ LastTransformer = null;
+ }
+ else
+ {
+ _transformers = transformers.ToArray();
+ _scopes = transformers.Select(x => TransformerScope.Everything).ToArray();
+ LastTransformer = transformers.Last() as TLastTransformer;
+ Contracts.Check(LastTransformer != null);
+ }
+ }
+
+ public ISchema GetOutputSchema(ISchema inputSchema)
+ {
+ Contracts.CheckValue(inputSchema, nameof(inputSchema));
+
+ var s = inputSchema;
+ foreach (var xf in _transformers)
+ s = xf.GetOutputSchema(s);
+ return s;
+ }
+
+ public IDataView Transform(IDataView input)
+ {
+ Contracts.CheckValue(input, nameof(input));
+
+ // Trigger schema propagation prior to transforming.
+ // REVIEW: does this actually constitute 'early warning', given that Transform call is lazy anyway?
+ GetOutputSchema(input.Schema);
+
+ var dv = input;
+ foreach (var xf in _transformers)
+ dv = xf.Transform(dv);
+ return dv;
+ }
+
+ public TransformerChain<ITransformer> GetModelFor(TransformerScope scopeFilter)
+ {
+ var xfs = new List<ITransformer>();
+ var scopes = new List<TransformerScope>();
+ for (int i = 0; i < _transformers.Length; i++)
+ {
+ if ((_scopes[i] & scopeFilter) != TransformerScope.None)
+ {
+ xfs.Add(_transformers[i]);
+ scopes.Add(_scopes[i]);
+ }
+ }
+ return new TransformerChain<ITransformer>(xfs.ToArray(), scopes.ToArray());
+ }
+
+ public TransformerChain<TNewLast> Append<TNewLast>(TNewLast transformer, TransformerScope scope = TransformerScope.Everything)
+ where TNewLast : class, ITransformer
+ {
+ Contracts.CheckValue(transformer, nameof(transformer));
+ return new TransformerChain<TNewLast>(_transformers.Append(transformer).ToArray(), _scopes.Append(scope).ToArray());
+ }
+
+ public void Save(ModelSaveContext ctx)
+ {
+ ctx.CheckAtModel();
+ ctx.SetVersionInfo(GetVersionInfo());
+
+ ctx.Writer.Write(_transformers.Length);
+
+ for (int i = 0; i < _transformers.Length; i++)
+ {
+ ctx.Writer.Write((int)_scopes[i]);
+ var dirName = string.Format(TransformDirTemplate, i);
+ ctx.SaveModel(_transformers[i], dirName);
+ }
+ }
+
+ ///
+ /// The loading constructor of transformer chain. Reverse of <see cref="Save(ModelSaveContext)"/>.
+ ///
+ internal TransformerChain(IHostEnvironment env, ModelLoadContext ctx)
+ {
+ int len = ctx.Reader.ReadInt32();
+ _transformers = new ITransformer[len];
+ _scopes = new TransformerScope[len];
+ for (int i = 0; i < len; i++)
+ {
+ _scopes[i] = (TransformerScope)(ctx.Reader.ReadInt32());
+ var dirName = string.Format(TransformDirTemplate, i);
+ ctx.LoadModel<ITransformer, SignatureLoadModel>(env, out _transformers[i], dirName);
+ }
+ if (len > 0)
+ LastTransformer = _transformers[len - 1] as TLastTransformer;
+ else
+ LastTransformer = null;
+ }
+
+ public void SaveTo(IHostEnvironment env, Stream outputStream)
+ {
+ using (var ch = env.Start("Saving pipeline"))
+ {
+ using (var rep = RepositoryWriter.CreateNew(outputStream, ch))
+ {
+ ch.Trace("Saving transformer chain");
+ ModelSaveContext.SaveModel(rep, this, TransformerChain.LoaderSignature);
+ rep.Commit();
+ }
+ }
+ }
+
+ public IEnumerator<ITransformer> GetEnumerator() => ((IEnumerable<ITransformer>)_transformers).GetEnumerator();
+
+ IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
+ }
+
+ ///
+ /// Saving/loading routines for transformer chains.
+ ///
+ public static class TransformerChain
+ {
+ public const string LoaderSignature = "TransformerChain";
+
+ public static TransformerChain<ITransformer> Create(IHostEnvironment env, ModelLoadContext ctx)
+ => new TransformerChain<ITransformer>(env, ctx);
+
+ ///
+ /// Save any transformer to a stream by wrapping it into a transformer chain.
+ ///
+ public static void SaveTo(this ITransformer transformer, IHostEnvironment env, Stream outputStream)
+ => new TransformerChain<ITransformer>(transformer).SaveTo(env, outputStream);
+
+ public static TransformerChain<ITransformer> LoadFrom(IHostEnvironment env, Stream stream)
+ {
+ using (var rep = RepositoryReader.Open(stream, env))
+ {
+ ModelLoadContext.LoadModel<TransformerChain<ITransformer>, SignatureLoadModel>(env, out var transformerChain, rep, LoaderSignature);
+ return transformerChain;
+ }
+ }
+ }
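+
+ // Illustrative usage sketch (editor's addition, not part of this change; 'env' and 'fittedTransformer' are
+ // hypothetical): any ITransformer can be persisted by wrapping it into a chain, and read back as a chain.
+ //
+ // using (var fs = File.Create("model.zip"))
+ //     fittedTransformer.SaveTo(env, fs);
+ //
+ // TransformerChain<ITransformer> loadedModel;
+ // using (var fs = File.OpenRead("model.zip"))
+ //     loadedModel = TransformerChain.LoadFrom(env, fs);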
+}
diff --git a/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj b/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj
index 4810872906..be16385901 100644
--- a/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj
+++ b/test/Microsoft.ML.Core.Tests/Microsoft.ML.Core.Tests.csproj
@@ -19,6 +19,7 @@
+
diff --git a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj
index 2d8013f545..d8d5cd3d5d 100644
--- a/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj
+++ b/test/Microsoft.ML.Tests/Microsoft.ML.Tests.csproj
@@ -1,4 +1,9 @@
+
+
+
+
+
@@ -22,4 +27,8 @@
+
+
+
+
\ No newline at end of file
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/AspirationalExamples.cs b/test/Microsoft.ML.Tests/Scenarios/Api/AspirationalExamples.cs
new file mode 100644
index 0000000000..9ba601ad82
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/AspirationalExamples.cs
@@ -0,0 +1,64 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+// This file contains code examples that currently do not even compile.
+// They serve as the reference point of the 'desired user-facing API' for the future work.
+
+using System;
+using System.Collections.Generic;
+using System.Text;
+
+namespace Microsoft.ML.Tests.Scenarios.Api
+{
+ public class AspirationalExamples
+ {
+ public class IrisPrediction
+ {
+ public string PredictedLabel;
+ }
+
+ public class IrisExample
+ {
+ public float SepalWidth { get; set; }
+ public float SepalLength { get; set; }
+ public float PetalWidth { get; set; }
+ public float PetalLength { get; set; }
+ }
+
+ public void FirstExperienceWithML()
+ {
+ // Load the data into the system.
+ string dataPath = "iris-data.txt";
+ var data = TextReader.FitAndRead(env, dataPath, row => (
+ Label: row.ReadString(0),
+ SepalWidth: row.ReadFloat(1),
+ SepalLength: row.ReadFloat(2),
+ PetalWidth: row.ReadFloat(3),
+ PetalLength: row.ReadFloat(4)));
+
+
+ var preprocess = data.Schema.MakeEstimator(row => (
+ // Convert string label to key.
+ Label: row.Label.DictionarizeLabel(),
+ // Concatenate all features into a vector.
+ Features: row.SepalWidth.ConcatWith(row.SepalLength, row.PetalWidth, row.PetalLength)));
+
+ var pipeline = preprocess
+ // Append the trainer to the training pipeline.
+ .AppendEstimator(row => row.Label.PredictWithSdca(row.Features))
+ .AppendEstimator(row => row.PredictedLabel.KeyToValue());
+
+ // Train the model and make some predictions.
+ var model = pipeline.Fit(data);
+
+ IrisPrediction prediction = model.Predict(new IrisExample
+ {
+ SepalWidth = 3.3f,
+ SepalLength = 1.6f,
+ PetalWidth = 0.2f,
+ PetalLength = 5.1f
+ });
+ }
+ }
+}
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs
new file mode 100644
index 0000000000..fb6c56674b
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs
@@ -0,0 +1,49 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Learners;
+using Xunit;
+
+namespace Microsoft.ML.Tests.Scenarios.Api
+{
+ public partial class ApiScenariosTests
+ {
+ ///
+ /// Cross-validation: Have a mechanism to do cross validation, that is, you come up with
+ /// a data source (optionally with stratification column), come up with an instantiable transform
+ /// and trainer pipeline, and it will handle (1) splitting up the data, (2) training the separate
+ /// pipelines on in-fold data, (3) scoring on the out-fold data, (4) returning the set of
+ /// evaluations and optionally trained pipes. (People always want metrics out of xfold,
+ /// they sometimes want the actual models too.)
+ ///
+ [Fact]
+ void New_CrossValidation()
+ {
+ var dataPath = GetDataPath(SentimentDataPath);
+ var testDataPath = GetDataPath(SentimentTestPath);
+
+ using (var env = new TlcEnvironment(seed: 1, conc: 1))
+ {
+
+ var data = new MyTextLoader(env, MakeSentimentTextLoaderArgs())
+ .FitAndRead(new MultiFileSource(dataPath));
+ // Pipeline.
+ var pipeline = new MyTextTransform(env, MakeSentimentTextTransformArgs())
+ .Append(new MySdca(env, new LinearClassificationTrainer.Arguments
+ {
+ NumThreads = 1,
+ ConvergenceTolerance = 1f
+ }, "Features", "Label"));
+
+ var cv = new MyCrossValidation.BinaryCrossValidator(env)
+ {
+ NumFolds = 2
+ };
+
+ var cvResult = cv.CrossValidate(data, pipeline);
+ }
+ }
+ }
+}
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs
new file mode 100644
index 0000000000..eb8a6991c5
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/DecomposableTrainAndPredict.cs
@@ -0,0 +1,50 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime.Api;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Learners;
+using System.Linq;
+using Xunit;
+
+namespace Microsoft.ML.Tests.Scenarios.Api
+{
+ public partial class ApiScenariosTests
+ {
+ ///
+ /// Decomposable train and predict: Train on Iris multiclass problem, which will require
+ /// a transform on labels. Be able to reconstitute the pipeline for a prediction only task,
+ /// which will essentially "drop" the transform over labels, while retaining the property
+ /// that the predicted label for this has a key-type, the probability outputs for the classes
+ /// have the class labels as slot names, etc. This should be do-able without ugly compromises like,
+ /// say, injecting a dummy label.
+ ///
+ [Fact]
+ void New_DecomposableTrainAndPredict()
+ {
+ var dataPath = GetDataPath(IrisDataPath);
+ using (var env = new TlcEnvironment())
+ {
+ var data = new MyTextLoader(env, MakeIrisTextLoaderArgs())
+ .FitAndRead(new MultiFileSource(dataPath));
+
+ var pipeline = new MyConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
+ .Append(new MyTermTransform(env, "Label"), TransformerScope.TrainTest)
+ .Append(new MySdcaMulticlass(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label"))
+ .Append(new MyKeyToValueTransform(env, "PredictedLabel"));
+
+ var model = pipeline.Fit(data).GetModelFor(TransformerScope.Scoring);
+ var engine = new MyPredictionEngine<IrisData, IrisPrediction>(env, model);
+
+ var testLoader = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
+ var testData = testLoader.AsEnumerable<IrisData>(env, false);
+ foreach (var input in testData.Take(20))
+ {
+ var prediction = engine.Predict(input);
+ Assert.True(prediction.PredictedLabel == input.Label);
+ }
+ }
+ }
+ }
+}
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Evaluation.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Evaluation.cs
new file mode 100644
index 0000000000..dcb7f515c2
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Evaluation.cs
@@ -0,0 +1,42 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Learners;
+using Xunit;
+
+namespace Microsoft.ML.Tests.Scenarios.Api
+{
+ public partial class ApiScenariosTests
+ {
+ ///
+ /// Evaluation: Similar to the simple train scenario, except instead of having some
+ /// predictive structure, be able to score another "test" data file, run the result
+ /// through an evaluator and get metrics like AUC, accuracy, PR curves, and whatnot.
+ /// Getting metrics out of this should be as straightforward and unannoying as possible.
+ ///
+ [Fact]
+ public void New_Evaluation()
+ {
+ var dataPath = GetDataPath(SentimentDataPath);
+ var testDataPath = GetDataPath(SentimentTestPath);
+
+ using (var env = new TlcEnvironment(seed: 1, conc: 1))
+ {
+ // Pipeline.
+ var pipeline = new MyTextLoader(env, MakeSentimentTextLoaderArgs())
+ .Append(new MyTextTransform(env, MakeSentimentTextTransformArgs()))
+ .Append(new MySdca(env, new LinearClassificationTrainer.Arguments { NumThreads = 1 }, "Features", "Label"));
+
+ // Train.
+ var model = pipeline.Fit(new MultiFileSource(dataPath));
+
+ // Evaluate on the test set.
+ var dataEval = model.Read(new MultiFileSource(testDataPath));
+ var evaluator = new MyBinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments() { });
+ var metrics = evaluator.Evaluate(dataEval);
+ }
+ }
+ }
+}
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs
new file mode 100644
index 0000000000..53cfba2fbf
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Extensibility.cs
@@ -0,0 +1,58 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime.Api;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Learners;
+using System;
+using System.Linq;
+using Xunit;
+
+namespace Microsoft.ML.Tests.Scenarios.Api
+{
+ public partial class ApiScenariosTests
+ {
+ ///
+ /// Extensibility: We can't possibly write every conceivable transform and should not try.
+ /// It should somehow be possible for a user to inject custom code to, say, transform data.
+ /// This might have a much steeper learning curve than the other usages (which merely involve
+ /// usage of already established components), but should still be possible.
+ ///
+ [Fact]
+ void New_Extensibility()
+ {
+ var dataPath = GetDataPath(IrisDataPath);
+ using (var env = new TlcEnvironment())
+ {
+ var data = new MyTextLoader(env, MakeIrisTextLoaderArgs())
+ .FitAndRead(new MultiFileSource(dataPath));
+
+ Action<IrisData, IrisData> action = (i, j) =>
+ {
+ j.Label = i.Label;
+ j.PetalLength = i.SepalLength > 3 ? i.PetalLength : i.SepalLength;
+ j.PetalWidth = i.PetalWidth;
+ j.SepalLength = i.SepalLength;
+ j.SepalWidth = i.SepalWidth;
+ };
+ var pipeline = new MyConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
+ .Append(new MyLambdaTransform<IrisData, IrisData>(env, action))
+ .Append(new MyTermTransform(env, "Label"), TransformerScope.TrainTest)
+ .Append(new MySdcaMulticlass(env, new SdcaMultiClassTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label"))
+ .Append(new MyKeyToValueTransform(env, "PredictedLabel"));
+
+ var model = pipeline.Fit(data).GetModelFor(TransformerScope.Scoring);
+ var engine = new MyPredictionEngine<IrisData, IrisPrediction>(env, model);
+
+ var testLoader = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
+ var testData = testLoader.AsEnumerable<IrisData>(env, false);
+ foreach (var input in testData.Take(20))
+ {
+ var prediction = engine.Predict(input);
+ Assert.True(prediction.PredictedLabel == input.Label);
+ }
+ }
+ }
+ }
+}
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/FileBasedSavingOfData.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/FileBasedSavingOfData.cs
new file mode 100644
index 0000000000..6e355f8db2
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/FileBasedSavingOfData.cs
@@ -0,0 +1,47 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Data.IO;
+using Microsoft.ML.Runtime.Learners;
+using Xunit;
+
+namespace Microsoft.ML.Tests.Scenarios.Api
+{
+ public partial class ApiScenariosTests
+ {
+ ///
+ /// File-based saving of data: Come up with transform pipeline. Transform training and
+ /// test data, and save the featurized data to some file, using the .idv format.
+ /// Train and evaluate multiple models over that pre-featurized data. (Useful for
+ /// sweeping scenarios, where you are training many times on the same data,
+ /// and don't necessarily want to transform it every single time.)
+ ///
+ [Fact]
+ void New_FileBasedSavingOfData()
+ {
+ var dataPath = GetDataPath(SentimentDataPath);
+ var testDataPath = GetDataPath(SentimentTestPath);
+
+ using (var env = new TlcEnvironment(seed: 1, conc: 1))
+ {
+ // Pipeline.
+ var pipeline = new MyTextLoader(env, MakeSentimentTextLoaderArgs())
+ .Append(new MyTextTransform(env, MakeSentimentTextTransformArgs()));
+
+ var trainData = pipeline.Fit(new MultiFileSource(dataPath)).Read(new MultiFileSource(dataPath));
+
+ using (var file = env.CreateOutputFile("i.idv"))
+ trainData.SaveAsBinary(env, file.CreateWriteStream());
+
+ var trainer = new MySdca(env, new LinearClassificationTrainer.Arguments { NumThreads = 1 }, "Features", "Label");
+ var loadedTrainData = new BinaryLoader(env, new BinaryLoader.Arguments(), new MultiFileSource("i.idv"));
+
+ // Train.
+ var model = trainer.Train(loadedTrainData);
+ DeleteOutputPath("i.idv");
+ }
+ }
+ }
+}
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs
new file mode 100644
index 0000000000..ae8f36183d
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs
@@ -0,0 +1,55 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.FastTree;
+using Microsoft.ML.Runtime.Internal.Calibration;
+using Microsoft.ML.Runtime.Internal.Internallearn;
+using Microsoft.ML.Runtime.Learners;
+using Microsoft.ML.Runtime.TextAnalytics;
+using System.Collections.Generic;
+using Xunit;
+
+namespace Microsoft.ML.Tests.Scenarios.Api
+{
+
+ public partial class ApiScenariosTests
+ {
+ ///
+ /// Introspective training: Models that produce outputs and are otherwise black boxes are of limited use;
+ /// it is also necessary often to understand at least to some degree what was learnt. To outline critical
+ /// scenarios that have come up multiple times:
+ /// *) When I train a linear model, I should be able to inspect coefficients.
+ /// *) For the tree ensemble learners, I should be able to inspect the trees.
+ /// *) For the LDA transform, I should be able to inspect the topics.
+ /// I view it as essential from a usability perspective that this be discoverable to someone without
+ /// having to read documentation. E.g.: if I have var lda = new LdaTransform().Fit(data) (I don't insist on that
+ /// exact signature, just giving the idea), then if I were to type "lda." in Visual Studio,
+ /// one of the auto-complete targets should be something like GetTopics.
+ ///
+
+ [Fact]
+ public void New_IntrospectiveTraining()
+ {
+ var dataPath = GetDataPath(SentimentDataPath);
+ var testDataPath = GetDataPath(SentimentTestPath);
+
+ using (var env = new TlcEnvironment(seed: 1, conc: 1))
+ {
+ var data = new MyTextLoader(env, MakeSentimentTextLoaderArgs())
+ .FitAndRead(new MultiFileSource(dataPath));
+
+ var pipeline = new MyTextTransform(env, MakeSentimentTextTransformArgs())
+ .Append(new MySdca(env, new LinearClassificationTrainer.Arguments { NumThreads = 1 }, "Features", "Label"));
+
+ // Train.
+ var model = pipeline.Fit(data);
+
+ // Get feature weights.
+ VBuffer<float> weights = default;
+ model.LastTransformer.InnerModel.GetFeatureWeights(ref weights);
+ }
+ }
+ }
+}
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs
new file mode 100644
index 0000000000..e6ed7469bb
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Metacomponents.cs
@@ -0,0 +1,40 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Api;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Learners;
+using System.Linq;
+using Xunit;
+
+namespace Microsoft.ML.Tests.Scenarios.Api
+{
+ public partial class ApiScenariosTests
+ {
+ ///
+ /// Meta-components: Meta-components (e.g., components that themselves instantiate components) should not be booby-trapped.
+ /// When specifying what trainer OVA should use, a user will be able to specify any binary classifier.
+ /// If they specify a regression or multi-class classifier, ideally that should be a compile error.
+ ///
+ [Fact]
+ public void New_Metacomponents()
+ {
+ var dataPath = GetDataPath(IrisDataPath);
+ using (var env = new TlcEnvironment())
+ {
+ var data = new MyTextLoader(env, MakeIrisTextLoaderArgs())
+ .FitAndRead(new MultiFileSource(dataPath));
+
+ var sdcaTrainer = new MySdca(env, new LinearClassificationTrainer.Arguments { MaxIterations = 100, Shuffle = true, NumThreads = 1 }, "Features", "Label");
+ var pipeline = new MyConcatTransform(env, "Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth")
+ .Append(new MyTermTransform(env, "Label"), TransformerScope.TrainTest)
+ .Append(new MyOva(env, sdcaTrainer))
+ .Append(new MyKeyToValueTransform(env, "PredictedLabel"));
+
+ var model = pipeline.Fit(data);
+ }
+ }
+ }
+}
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs
new file mode 100644
index 0000000000..e8386d166e
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs
@@ -0,0 +1,57 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime.Api;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Learners;
+using System.Threading.Tasks;
+using Xunit;
+
+namespace Microsoft.ML.Tests.Scenarios.Api
+{
+ public partial class ApiScenariosTests
+ {
+ ///
+ /// Multi-threaded prediction. A twist on "Simple train and predict", where we account for the fact that
+ /// multiple threads may want predictions at the same time. Because we deliberately do not
+ /// reallocate internal memory buffers on every single prediction, the PredictionEngine
+ /// (or its estimator/transformer based successor) is, like most stateful .NET objects,
+ /// fundamentally not thread safe. This is deliberate and as designed. However, some mechanism
+ /// to enable multi-threaded scenarios (e.g., a web server servicing requests) should be possible
+ /// and performant in the new API.
+ ///
+ [Fact]
+ void New_MultithreadedPrediction()
+ {
+ var dataPath = GetDataPath(SentimentDataPath);
+ var testDataPath = GetDataPath(SentimentTestPath);
+
+ using (var env = new TlcEnvironment(seed: 1, conc: 1))
+ {
+ // Pipeline.
+ var pipeline = new MyTextLoader(env, MakeSentimentTextLoaderArgs())
+ .Append(new MyTextTransform(env, MakeSentimentTextTransformArgs()))
+ .Append(new MySdca(env, new LinearClassificationTrainer.Arguments { NumThreads = 1 }, "Features", "Label"));
+
+ // Train.
+ var model = pipeline.Fit(new MultiFileSource(dataPath));
+
+ // Create prediction engine and test predictions.
+ var engine = new MyPredictionEngine<SentimentData, SentimentPrediction>(env, model.Transformer);
+
+ // Take a couple examples out of the test data and run predictions on top.
+ var testData = model.Reader.Read(new MultiFileSource(GetDataPath(SentimentTestPath)))
+ .AsEnumerable<SentimentData>(env, false);
+
+ Parallel.ForEach(testData, (input) =>
+ {
+ lock (engine)
+ {
+ var prediction = engine.Predict(input);
+ }
+ });
+ }
+ }
+ }
+}
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs
new file mode 100644
index 0000000000..12a6a2db20
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs
@@ -0,0 +1,52 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Models;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Learners;
+using Xunit;
+
+namespace Microsoft.ML.Tests.Scenarios.Api
+{
+ public partial class ApiScenariosTests
+ {
+ ///
+ /// Reconfigurable predictions: The following should be possible: A user trains a binary classifier,
+ /// and through the test evaluator gets a PR curve, then based on the PR curve picks a new threshold
+ /// and configures the scorer (or more precisely instantiates a new scorer over the same predictor)
+ /// with some threshold derived from that.
+ ///
+ [Fact]
+ public void New_ReconfigurablePrediction()
+ {
+ var dataPath = GetDataPath(SentimentDataPath);
+ var testDataPath = GetDataPath(SentimentTestPath);
+
+ using (var env = new TlcEnvironment(seed: 1, conc: 1))
+ {
+ var dataReader = new MyTextLoader(env, MakeSentimentTextLoaderArgs())
+ .Fit(new MultiFileSource(dataPath));
+
+ var data = dataReader.Read(new MultiFileSource(dataPath));
+ var testData = dataReader.Read(new MultiFileSource(testDataPath));
+
+ // Pipeline.
+ var pipeline = new MyTextTransform(env, MakeSentimentTextTransformArgs())
+ .Fit(data);
+
+ var trainer = new MySdca(env, new LinearClassificationTrainer.Arguments { NumThreads = 1 }, "Features", "Label");
+ var trainData = pipeline.Transform(data);
+ var model = trainer.Fit(trainData);
+
+ var scoredTest = model.Transform(pipeline.Transform(testData));
+ var metrics = new MyBinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments()).Evaluate(scoredTest, "Label", "Probability");
+
+ var newModel = model.Clone(new BinaryClassifierScorer.Arguments { Threshold = 0.01f, ThresholdColumn = DefaultColumnNames.Probability });
+ var newScoredTest = newModel.Transform(pipeline.Transform(testData));
+ var newMetrics = new MyBinaryClassifierEvaluator(env, new BinaryClassifierEvaluator.Arguments { Threshold = 0.01f, UseRawScoreThreshold = false }).Evaluate(newScoredTest, "Label", "Probability");
+ }
+
+ }
+ }
+}
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs
new file mode 100644
index 0000000000..76c12e6068
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs
@@ -0,0 +1,53 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime.Api;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Learners;
+using Xunit;
+using System.Linq;
+
+namespace Microsoft.ML.Tests.Scenarios.Api
+{
+ public partial class ApiScenariosTests
+ {
+ ///
+ /// Start with a dataset in a text file. Run text featurization on text values.
+ /// Train a linear model over that. (I am thinking sentiment classification.)
+ /// Out of the result, produce some structure over which you can get predictions programmatically
+ /// (e.g., the prediction does not happen over a file as it did during training).
+ ///
+ [Fact]
+ public void New_SimpleTrainAndPredict()
+ {
+ var dataPath = GetDataPath(SentimentDataPath);
+ var testDataPath = GetDataPath(SentimentTestPath);
+
+ using (var env = new TlcEnvironment(seed: 1, conc: 1))
+ {
+ // Pipeline.
+ var pipeline = new MyTextLoader(env, MakeSentimentTextLoaderArgs())
+ .Append(new MyTextTransform(env, MakeSentimentTextTransformArgs()))
+ .Append(new MySdca(env, new LinearClassificationTrainer.Arguments { NumThreads = 1 }, "Features", "Label"));
+
+ // Train.
+ var model = pipeline.Fit(new MultiFileSource(dataPath));
+
+ // Create prediction engine and test predictions.
+ var engine = new MyPredictionEngine<SentimentData, SentimentPrediction>(env, model.Transformer);
+
+ // Take a couple examples out of the test data and run predictions on top.
+ var testData = model.Reader.Read(new MultiFileSource(GetDataPath(SentimentTestPath)))
+ .AsEnumerable<SentimentData>(env, false);
+ foreach (var input in testData.Take(5))
+ {
+ var prediction = engine.Predict(input);
+ // Verify that predictions match and scores are separated from zero.
+ Assert.Equal(input.Sentiment, prediction.Sentiment);
+ Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1);
+ }
+ }
+ }
+ }
+}
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs
new file mode 100644
index 0000000000..f5b5b98c90
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs
@@ -0,0 +1,65 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Core.Data;
+using Microsoft.ML.Runtime.Api;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Learners;
+using System.Linq;
+using Xunit;
+
+namespace Microsoft.ML.Tests.Scenarios.Api
+{
+ public partial class ApiScenariosTests
+ {
+ ///
+ /// Train, save/load model, predict:
+ /// Serve the scenario where training and prediction happen in different processes (or even different machines).
+ /// The actual test will not run in different processes, but will simulate the idea that the
+ /// "communication pipe" is just a serialized model of some form.
+ ///
+ [Fact]
+ public void New_TrainSaveModelAndPredict()
+ {
+ var dataPath = GetDataPath(SentimentDataPath);
+ var testDataPath = GetDataPath(SentimentTestPath);
+
+ using (var env = new TlcEnvironment(seed: 1, conc: 1))
+ {
+ // Pipeline.
+ var pipeline = new MyTextLoader(env, MakeSentimentTextLoaderArgs())
+ .Append(new MyTextTransform(env, MakeSentimentTextTransformArgs()))
+ .Append(new MySdca(env, new LinearClassificationTrainer.Arguments { NumThreads = 1 }, "Features", "Label"));
+
+ // Train.
+ var model = pipeline.Fit(new MultiFileSource(dataPath));
+
+ ITransformer loadedModel;
+ using (var file = env.CreateTempFile())
+ {
+ // Save model.
+ using (var fs = file.CreateWriteStream())
+ model.Transformer.SaveTo(env, fs);
+
+ // Load model.
+ loadedModel = TransformerChain.LoadFrom(env, file.OpenReadStream());
+ }
+
+ // Create prediction engine and test predictions.
+ var engine = new MyPredictionEngine<SentimentData, SentimentPrediction>(env, loadedModel);
+
+ // Take a couple examples out of the test data and run predictions on top.
+ var testData = model.Reader.Read(new MultiFileSource(GetDataPath(SentimentTestPath)))
+ .AsEnumerable<SentimentData>(env, false);
+ foreach (var input in testData.Take(5))
+ {
+ var prediction = engine.Predict(input);
+ // Verify that predictions match and scores are separated from zero.
+ Assert.Equal(input.Sentiment, prediction.Sentiment);
+ Assert.True(input.Sentiment && prediction.Score > 1 || !input.Sentiment && prediction.Score < -1);
+ }
+ }
+ }
+ }
+}
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs
new file mode 100644
index 0000000000..2e687c9138
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs
@@ -0,0 +1,46 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Learners;
+using Xunit;
+
+namespace Microsoft.ML.Tests.Scenarios.Api
+{
+ public partial class ApiScenariosTests
+ {
+ ///
+ /// Train with initial predictor: Similar to the simple train scenario, but also accept a pre-trained initial model.
+ /// The scenario might be one of the online linear learners that can take advantage of this, e.g., averaged perceptron.
+ ///
+ [Fact]
+ public void New_TrainWithInitialPredictor()
+ {
+ var dataPath = GetDataPath(SentimentDataPath);
+
+ using (var env = new TlcEnvironment(seed: 1, conc: 1))
+ {
+ // Pipeline.
+ var pipeline = new MyTextLoader(env, MakeSentimentTextLoaderArgs())
+ .Append(new MyTextTransform(env, MakeSentimentTextTransformArgs()));
+
+ // Train the pipeline, prepare train set.
+ var reader = pipeline.Fit(new MultiFileSource(dataPath));
+ var trainData = reader.Read(new MultiFileSource(dataPath));
+
+
+ // Train the first predictor.
+ var trainer = new MySdca(env, new LinearClassificationTrainer.Arguments
+ {
+ NumThreads = 1
+ }, "Features", "Label");
+ var firstModel = trainer.Fit(trainData);
+
+ // Train the second predictor on the same data.
+ var secondTrainer = new MyAveragedPerceptron(env, new AveragedPerceptronTrainer.Arguments(), "Features", "Label");
+ var finalModel = secondTrainer.Train(trainData, firstModel.InnerModel);
+ }
+ }
+ }
+}
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs
new file mode 100644
index 0000000000..25eebafc45
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs
@@ -0,0 +1,39 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime.Data;
+using Xunit;
+
+namespace Microsoft.ML.Tests.Scenarios.Api
+{
+ public partial class ApiScenariosTests
+ {
+ ///
+ /// Train with validation set: Similar to the simple train scenario, but also support a validation set.
+ /// The learner might be trees with early stopping.
+ ///
+ [Fact]
+ public void New_TrainWithValidationSet()
+ {
+ var dataPath = GetDataPath(SentimentDataPath);
+ var validationDataPath = GetDataPath(SentimentTestPath);
+
+ using (var env = new TlcEnvironment(seed: 1, conc: 1))
+ {
+ // Pipeline.
+ var pipeline = new MyTextLoader(env, MakeSentimentTextLoaderArgs())
+ .Append(new MyTextTransform(env, MakeSentimentTextTransformArgs()));
+
+ // Train the pipeline, prepare train and validation set.
+ var reader = pipeline.Fit(new MultiFileSource(dataPath));
+ var trainData = reader.Read(new MultiFileSource(dataPath));
+ var validData = reader.Read(new MultiFileSource(validationDataPath));
+
+ // Train model with validation set.
+ var trainer = new MySdca(env, new Runtime.Learners.LinearClassificationTrainer.Arguments(), "Features", "Label");
+ var model = trainer.Train(trainData, validData);
+ }
+ }
+ }
+}
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs
new file mode 100644
index 0000000000..d0f79cc1f1
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs
@@ -0,0 +1,59 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime.Data;
+using Xunit;
+
+namespace Microsoft.ML.Tests.Scenarios.Api
+{
+ public partial class ApiScenariosTests
+ {
+ ///
+ /// Visibility: It should, possibly through the debugger, be not such a pain to actually
+ /// see what is happening to your data when you apply this or that transform. E.g.: if I
+ /// were to have the text "Help I'm a bug!" I should be able to see the steps where it is
+ /// normalized to "help i'm a bug" then tokenized into ["help", "i'm", "a", "bug"] then
+ /// mapped into term numbers [203, 25, 3, 511] then projected into the sparse
+ /// float vector {3:1, 25:1, 203:1, 511:1}, etc. etc.
+ ///
+ [Fact]
+ void New_Visibility()
+ {
+ var dataPath = GetDataPath(SentimentDataPath);
+ using (var env = new TlcEnvironment(seed: 1, conc: 1))
+ {
+ var pipeline = new MyTextLoader(env, MakeSentimentTextLoaderArgs())
+ .Append(new MyTextTransform(env, MakeSentimentTextTransformArgs()));
+ var data = pipeline.FitAndRead(new MultiFileSource(dataPath));
+ // In order to find out available column names, you can go through schema and check
+ // column names and appropriate type for getter.
+ for (int i = 0; i < data.Schema.ColumnCount; i++)
+ {
+ var columnName = data.Schema.GetColumnName(i);
+ var columnType = data.Schema.GetColumnType(i).RawType;
+ }
+
+ using (var cursor = data.GetRowCursor(x => true))
+ {
+ Assert.True(cursor.Schema.TryGetColumnIndex("SentimentText", out int textColumn));
+ Assert.True(cursor.Schema.TryGetColumnIndex("Features_TransformedText", out int transformedTextColumn));
+ Assert.True(cursor.Schema.TryGetColumnIndex("Features", out int featureColumn));
+
+ var originalTextGetter = cursor.GetGetter<DvText>(textColumn);
+ var transformedTextGetter = cursor.GetGetter<VBuffer<DvText>>(transformedTextColumn);
+ var featureGetter = cursor.GetGetter<VBuffer<float>>(featureColumn);
+ DvText text = default;
+ VBuffer<DvText> transformedText = default;
+ VBuffer<float> features = default;
+ while (cursor.MoveNext())
+ {
+ originalTextGetter(ref text);
+ transformedTextGetter(ref transformedText);
+ featureGetter(ref features);
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs
new file mode 100644
index 0000000000..1c2156b3e5
--- /dev/null
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Wrappers.cs
@@ -0,0 +1,685 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Core.Data;
+using Microsoft.ML.Models;
+using Microsoft.ML.Runtime;
+using Microsoft.ML.Runtime.Api;
+using Microsoft.ML.Runtime.CommandLine;
+using Microsoft.ML.Runtime.Data;
+using Microsoft.ML.Runtime.Data.IO;
+using Microsoft.ML.Runtime.EntryPoints;
+using Microsoft.ML.Runtime.Internal.Internallearn;
+using Microsoft.ML.Runtime.Learners;
+using Microsoft.ML.Runtime.Model;
+using Microsoft.ML.Runtime.Training;
+using Microsoft.ML.Tests.Scenarios.Api;
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+
+[assembly: LoadableClass(typeof(TransformWrapper), null, typeof(SignatureLoadModel),
+ "Transform wrapper", TransformWrapper.LoaderSignature)]
+[assembly: LoadableClass(typeof(LoaderWrapper), null, typeof(SignatureLoadModel),
+ "Loader wrapper", LoaderWrapper.LoaderSignature)]
+
+namespace Microsoft.ML.Tests.Scenarios.Api
+{
+ using TScalarPredictor = IPredictorProducing<float>;
+ using TWeightsPredictor = IPredictorWithFeatureWeights<float>;
+
+ public sealed class LoaderWrapper : IDataReader<IMultiStreamSource>, ICanSaveModel
+ {
+ public const string LoaderSignature = "LoaderWrapper";
+
+ private readonly IHostEnvironment _env;
+ private readonly Func<IMultiStreamSource, IDataLoader> _loaderFactory;
+
+ public LoaderWrapper(IHostEnvironment env, Func<IMultiStreamSource, IDataLoader> loaderFactory)
+ {
+ _env = env;
+ _loaderFactory = loaderFactory;
+ }
+
+ public ISchema GetOutputSchema()
+ {
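+ // Reading from a null multi-file source yields an empty data view; only its schema is needed here.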
+ var emptyData = Read(new MultiFileSource(null));
+ return emptyData.Schema;
+ }
+
+ public IDataView Read(IMultiStreamSource input) => _loaderFactory(input);
+
+ public void Save(ModelSaveContext ctx)
+ {
+ ctx.CheckAtModel();
+ ctx.SetVersionInfo(GetVersionInfo());
+ var ldr = Read(new MultiFileSource(null));
+ ctx.SaveModel(ldr, "Loader");
+ }
+
+ private static VersionInfo GetVersionInfo()
+ {
+ return new VersionInfo(
+ modelSignature: "LDR WRPR",
+ verWrittenCur: 0x00010001, // Initial
+ verReadableCur: 0x00010001,
+ verWeCanReadBack: 0x00010001,
+ loaderSignature: LoaderSignature);
+ }
+
+ public LoaderWrapper(IHostEnvironment env, ModelLoadContext ctx)
+ {
+ ctx.CheckAtModel(GetVersionInfo());
+ ctx.LoadModel<IDataLoader, SignatureLoadDataLoader>(env, out var loader, "Loader", new MultiFileSource(null));
+
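+ // Round-trip the loaded loader through an in-memory repository so the factory below can rebind it to any new source.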
+ var loaderStream = new MemoryStream();
+ using (var rep = RepositoryWriter.CreateNew(loaderStream))
+ {
+ ModelSaveContext.SaveModel(rep, loader, "Loader");
+ rep.Commit();
+ }
+
+ _env = env;
+ _loaderFactory = (IMultiStreamSource source) =>
+ {
+ using (var rep = RepositoryReader.Open(loaderStream))
+ {
+ ModelLoadContext.LoadModel<IDataLoader, SignatureLoadDataLoader>(env, out var ldr, rep, "Loader", source);
+ return ldr;
+ }
+ };
+
+ }
+ }
+
+ public class TransformWrapper : ITransformer, ICanSaveModel
+ {
+ public const string LoaderSignature = "TransformWrapper";
+ private const string TransformDirTemplate = "Step_{0:000}";
+
+ protected readonly IHostEnvironment _env;
+ protected readonly IDataView _xf;
+
+ public TransformWrapper(IHostEnvironment env, IDataView xf)
+ {
+ _env = env;
+ _xf = xf;
+ }
+
+ public ISchema GetOutputSchema(ISchema inputSchema)
+ {
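+ // Apply the saved transform chain to an empty view with the given schema; only schema propagation happens here.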
+ var dv = new EmptyDataView(_env, inputSchema);
+ var output = ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, dv);
+ return output.Schema;
+ }
+
+ public void Save(ModelSaveContext ctx)
+ {
+ ctx.CheckAtModel();
+ ctx.SetVersionInfo(GetVersionInfo());
+
+ var dataPipe = _xf;
+ var transforms = new List<IDataTransform>();
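+ // Walk the chain from the last transform back to its source, then save the steps in application order.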
+ while (dataPipe is IDataTransform xf)
+ {
+ // REVIEW: a malicious user could construct a loop in the Source chain, that would
+ // cause this method to iterate forever (and throw something when the list overflows). There's
+ // no way to insulate from ALL malicious behavior.
+ transforms.Add(xf);
+ dataPipe = xf.Source;
+ Contracts.AssertValue(dataPipe);
+ }
+ transforms.Reverse();
+
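+ // Persist the input schema as an empty BinaryLoader so the chain can be re-attached to data when loaded back.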
+ ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_env, c, dataPipe.Schema));
+
+ ctx.Writer.Write(transforms.Count);
+ for (int i = 0; i < transforms.Count; i++)
+ {
+ var dirName = string.Format(TransformDirTemplate, i);
+ ctx.SaveModel(transforms[i], dirName);
+ }
+ }
+
+ private static VersionInfo GetVersionInfo()
+ {
+ return new VersionInfo(
+ modelSignature: "XF WRPR",
+ verWrittenCur: 0x00010001, // Initial
+ verReadableCur: 0x00010001,
+ verWeCanReadBack: 0x00010001,
+ loaderSignature: LoaderSignature);
+ }
+
+ public TransformWrapper(IHostEnvironment env, ModelLoadContext ctx)
+ {
+ ctx.CheckAtModel(GetVersionInfo());
+ int n = ctx.Reader.ReadInt32();
+
+ ctx.LoadModel<IDataLoader, SignatureLoadDataLoader>(env, out var loader, "Loader", new MultiFileSource(null));
+
+ IDataView data = loader;
+ for (int i = 0; i < n; i++)
+ {
+ var dirName = string.Format(TransformDirTemplate, i);
+ ctx.LoadModel<IDataTransform, SignatureLoadDataTransform>(env, out var xf, dirName, data);
+ data = xf;
+ }
+
+ _env = env;
+ _xf = data;
+ }
+
+ public IDataView Transform(IDataView input) => ApplyTransformUtils.ApplyAllTransformsToData(_env, _xf, input);
+ }
+
+ public interface IPredictorTransformer<TModel> : ITransformer
+ {
+ TModel InnerModel { get; }
+ }
+
+ public class ScorerWrapper<TModel> : TransformWrapper, IPredictorTransformer<TModel>
+ where TModel : IPredictor
+ {
+ protected readonly string _featureColumn;
+
+ public ScorerWrapper(IHostEnvironment env, IDataView scorer, TModel trainedModel, string featureColumn)
+ : base(env, scorer)
+ {
+ _featureColumn = featureColumn;
+ InnerModel = trainedModel;
+ }
+
+ public TModel InnerModel { get; }
+ }
+
+ public class BinaryScorerWrapper<TModel> : ScorerWrapper<TModel>
+ where TModel : IPredictor
+ {
+ public BinaryScorerWrapper(IHostEnvironment env, TModel model, ISchema inputSchema, string featureColumn, BinaryClassifierScorer.Arguments args)
+ : base(env, MakeScorer(env, inputSchema, featureColumn, model, args), model, featureColumn)
+ {
+ }
+
+ private static IDataView MakeScorer(IHostEnvironment env, ISchema schema, string featureColumn, TModel model, BinaryClassifierScorer.Arguments args)
+ {
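+ // Bind the trained model to the input schema through an empty data view, then wrap it in a binary classifier scorer.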
+ var settings = $"Binary{{{CmdParser.GetSettings(env, args, new BinaryClassifierScorer.Arguments())}}}";
+
+ var scorerFactorySettings = CmdParser.CreateComponentFactory(
+ typeof(IComponentFactory<IDataView, ISchemaBoundMapper, RoleMappedSchema, IDataScorerTransform>),
+ typeof(SignatureDataScorer),
+ settings);
+
+ var bindable = ScoreUtils.GetSchemaBindableMapper(env, model, scorerFactorySettings: scorerFactorySettings);
+ var edv = new EmptyDataView(env, schema);
+ var data = new RoleMappedData(edv, "Label", featureColumn, opt: true);
+
+ return new BinaryClassifierScorer(env, args, data.Data, bindable.Bind(env, data.Schema), data.Schema);
+ }
+
+ public BinaryScorerWrapper<TModel> Clone(BinaryClassifierScorer.Arguments scorerArgs)
+ {
+ var scorer = _xf as IDataScorerTransform;
+ return new BinaryScorerWrapper<TModel>(_env, InnerModel, scorer.Source.Schema, _featureColumn, scorerArgs);
+ }
+ }
+
+ public class MyTextLoader : IDataReaderEstimator<IMultiStreamSource, LoaderWrapper>
+ {
+ private readonly TextLoader.Arguments _args;
+ private readonly IHostEnvironment _env;
+
+ public MyTextLoader(IHostEnvironment env, TextLoader.Arguments args)
+ {
+ _env = env;
+ _args = args;
+ }
+
+ public LoaderWrapper Fit(IMultiStreamSource input)
+ {
+ return new LoaderWrapper(_env, x => new TextLoader(_env, _args, x));
+ }
+
+ public SchemaShape GetOutputSchema()
+ {
+ var emptyData = new TextLoader(_env, _args, new MultiFileSource(null));
+ return SchemaShape.Create(emptyData.Schema);
+ }
+ }
+
+ public interface ITrainerEstimator<TTransformer, TModel> : IEstimator<TTransformer>
+ where TTransformer : IPredictorTransformer<TModel>
+ where TModel : IPredictor
+ {
+ TrainerInfo TrainerInfo { get; }
+ }
+
+ public abstract class TrainerBase<TTransformer, TModel> : ITrainerEstimator<TTransformer, TModel>
+ where TTransformer : ScorerWrapper<TModel>
+ where TModel : IPredictor
+ {
+ protected readonly IHostEnvironment _env;
+ protected readonly string _featureCol;
+ protected readonly string _labelCol;
+
+ public TrainerInfo TrainerInfo { get; }
+
+ protected TrainerBase(IHostEnvironment env, TrainerInfo trainerInfo, string featureColumn, string labelColumn)
+ {
+ _env = env;
+ _featureCol = featureColumn;
+ _labelCol = labelColumn;
+ TrainerInfo = trainerInfo;
+ }
+
+ public TTransformer Fit(IDataView input)
+ {
+ return TrainTransformer(input);
+ }
+
+ protected TTransformer TrainTransformer(IDataView trainSet,
+ IDataView validationSet = null, IPredictor initPredictor = null)
+ {
+ var cachedTrain = TrainerInfo.WantCaching ? new CacheDataView(_env, trainSet, prefetch: null) : trainSet;
+
+ var trainRoles = new RoleMappedData(cachedTrain, label: _labelCol, feature: _featureCol);
+ var emptyData = new EmptyDataView(_env, trainSet.Schema);
+ IDataView normalizer = emptyData;
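+ // If normalization is added below, it is re-rooted onto an empty view so it can be re-applied to validation and scoring data.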
+
+ if (TrainerInfo.NeedNormalization && trainRoles.Schema.FeaturesAreNormalized() == false)
+ {
+ var view = NormalizeTransform.CreateMinMaxNormalizer(_env, trainRoles.Data, name: trainRoles.Schema.Feature.Name);
+ normalizer = ApplyTransformUtils.ApplyAllTransformsToData(_env, view, emptyData, cachedTrain);
+
+ trainRoles = new RoleMappedData(view, trainRoles.Schema.GetColumnRoleNames());
+ }
+
+ RoleMappedData validRoles;
+
+ if (validationSet == null)
+ validRoles = null;
+ else
+ {
+ var cachedValid = TrainerInfo.WantCaching ? new CacheDataView(_env, validationSet, prefetch: null) : validationSet;
+ cachedValid = ApplyTransformUtils.ApplyAllTransformsToData(_env, normalizer, cachedValid);
+ validRoles = new RoleMappedData(cachedValid, label: _labelCol, feature: _featureCol);
+ }
+
+ var pred = TrainCore(new TrainContext(trainRoles, validRoles, initPredictor));
+
+ var scoreRoles = new RoleMappedData(normalizer, label: _labelCol, feature: _featureCol);
+ return MakeScorer(pred, scoreRoles);
+ }
+
+ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ throw new NotImplementedException();
+ }
+
+ protected abstract TModel TrainCore(TrainContext trainContext);
+
+ protected abstract TTransformer MakeScorer(TModel predictor, RoleMappedData data);
+
+ protected ScorerWrapper<TModel> MakeScorerBasic(TModel predictor, RoleMappedData data)
+ {
+ var scorer = ScoreUtils.GetScorer(predictor, data, _env, data.Schema);
+ return (TTransformer)(new ScorerWrapper<TModel>(_env, scorer, predictor, data.Schema.Feature.Name));
+ }
+ }
+
+ public class MyTextTransform : IEstimator<TransformWrapper>
+ {
+ private readonly IHostEnvironment _env;
+ private readonly TextTransform.Arguments _args;
+
+ public MyTextTransform(IHostEnvironment env, TextTransform.Arguments args)
+ {
+ _env = env;
+ _args = args;
+ }
+
+ public TransformWrapper Fit(IDataView input)
+ {
+ var xf = TextTransform.Create(_env, _args, input);
+ var empty = new EmptyDataView(_env, input.Schema);
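+ // Re-root the fitted transform chain onto an empty view so the wrapper does not hold on to the training data.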
+ var chunk = ApplyTransformUtils.ApplyAllTransformsToData(_env, xf, empty, input);
+ return new TransformWrapper(_env, chunk);
+ }
+
+ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ throw new NotImplementedException();
+ }
+ }
+
+ public class MyTermTransform : IEstimator<TransformWrapper>
+ {
+ private readonly IHostEnvironment _env;
+ private readonly string _column;
+ private readonly string _srcColumn;
+
+ public MyTermTransform(IHostEnvironment env, string column, string srcColumn = null)
+ {
+ _env = env;
+ _column = column;
+ _srcColumn = srcColumn;
+ }
+
+ public TransformWrapper Fit(IDataView input)
+ {
+ var xf = new TermTransform(_env, input, _column, _srcColumn);
+ var empty = new EmptyDataView(_env, input.Schema);
+ var chunk = ApplyTransformUtils.ApplyAllTransformsToData(_env, xf, empty, input);
+ return new TransformWrapper(_env, chunk);
+ }
+
+ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ throw new NotImplementedException();
+ }
+ }
+
+ public class MyConcatTransform : IEstimator<TransformWrapper>
+ {
+ private readonly IHostEnvironment _env;
+ private readonly string _name;
+ private readonly string[] _source;
+
+ public MyConcatTransform(IHostEnvironment env, string name, params string[] source)
+ {
+ _env = env;
+ _name = name;
+ _source = source;
+ }
+
+ public TransformWrapper Fit(IDataView input)
+ {
+ var xf = new ConcatTransform(_env, input, _name, _source);
+ var empty = new EmptyDataView(_env, input.Schema);
+ var chunk = ApplyTransformUtils.ApplyAllTransformsToData(_env, xf, empty, input);
+ return new TransformWrapper(_env, chunk);
+ }
+
+ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ throw new NotImplementedException();
+ }
+ }
+
+ public class MyKeyToValueTransform : IEstimator<TransformWrapper>
+ {
+ private readonly IHostEnvironment _env;
+ private readonly string _name;
+ private readonly string _source;
+
+ public MyKeyToValueTransform(IHostEnvironment env, string name, string source = null)
+ {
+ _env = env;
+ _name = name;
+ _source = source;
+ }
+
+ public TransformWrapper Fit(IDataView input)
+ {
+ var xf = new KeyToValueTransform(_env, input, _name, _source);
+ var empty = new EmptyDataView(_env, input.Schema);
+ var chunk = ApplyTransformUtils.ApplyAllTransformsToData(_env, xf, empty, input);
+ return new TransformWrapper(_env, chunk);
+ }
+
+ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ throw new NotImplementedException();
+ }
+ }
+
+ public sealed class MySdca : TrainerBase<BinaryScorerWrapper<TWeightsPredictor>, TWeightsPredictor>
+ {
+ private readonly LinearClassificationTrainer.Arguments _args;
+
+ public MySdca(IHostEnvironment env, LinearClassificationTrainer.Arguments args, string featureCol, string labelCol)
+ : base(env, new TrainerInfo(), featureCol, labelCol)
+ {
+ _args = args;
+ }
+
+ protected override TWeightsPredictor TrainCore(TrainContext context) => new LinearClassificationTrainer(_env, _args).Train(context);
+
+ public ITransformer Train(IDataView trainData, IDataView validationData = null) => TrainTransformer(trainData, validationData);
+
+ protected override BinaryScorerWrapper<TWeightsPredictor> MakeScorer(TWeightsPredictor predictor, RoleMappedData data)
+ => new BinaryScorerWrapper(_env, predictor, data.Data.Schema, _featureCol, new BinaryClassifierScorer.Arguments());
+ }
+
+ public sealed class MySdcaMulticlass : TrainerBase<ScorerWrapper<IPredictor>, IPredictor>
+ {
+ private readonly SdcaMultiClassTrainer.Arguments _args;
+
+ public MySdcaMulticlass(IHostEnvironment env, SdcaMultiClassTrainer.Arguments args, string featureCol, string labelCol)
+ : base(env, new TrainerInfo(), featureCol, labelCol)
+ {
+ _args = args;
+ }
+
+ protected override ScorerWrapper<IPredictor> MakeScorer(IPredictor predictor, RoleMappedData data) => MakeScorerBasic(predictor, data);
+
+ protected override IPredictor TrainCore(TrainContext context) => new SdcaMultiClassTrainer(_env, _args).Train(context);
+ }
+
+ public sealed class MyAveragedPerceptron : TrainerBase<BinaryScorerWrapper<IPredictor>, IPredictor>
+ {
+ private readonly AveragedPerceptronTrainer _trainer;
+
+ public MyAveragedPerceptron(IHostEnvironment env, AveragedPerceptronTrainer.Arguments args, string featureCol, string labelCol)
+ : base(env, new TrainerInfo(caching: false), featureCol, labelCol)
+ {
+ _trainer = new AveragedPerceptronTrainer(env, args);
+ }
+
+ protected override IPredictor TrainCore(TrainContext trainContext) => _trainer.Train(trainContext);
+
+ public ITransformer Train(IDataView trainData, IPredictor initialPredictor)
+ {
+ return TrainTransformer(trainData, initPredictor: initialPredictor);
+ }
+
+ protected override BinaryScorerWrapper<IPredictor> MakeScorer(IPredictor predictor, RoleMappedData data)
+ => new BinaryScorerWrapper(_env, predictor, data.Data.Schema, _featureCol, new BinaryClassifierScorer.Arguments());
+ }
+
+ public sealed class MyPredictionEngine<TSrc, TDst>
+ where TSrc : class
+ where TDst : class, new()
+ {
+ private readonly PredictionEngine<TSrc, TDst> _engine;
+
+ public MyPredictionEngine(IHostEnvironment env, ITransformer pipe)
+ {
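+ // An empty data view of TSrc lets the transformer compute its output schema before the prediction engine is built.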
+ IDataView dv = env.CreateDataView(new TSrc[0]);
+ _engine = env.CreatePredictionEngine<TSrc, TDst>(pipe.Transform(dv));
+ }
+
+ public TDst Predict(TSrc example)
+ {
+ return _engine.Predict(example);
+ }
+ }
+
+ public sealed class MyBinaryClassifierEvaluator
+ {
+ private readonly IHostEnvironment _env;
+ private readonly BinaryClassifierEvaluator _evaluator;
+
+ public MyBinaryClassifierEvaluator(IHostEnvironment env, BinaryClassifierEvaluator.Arguments args)
+ {
+ _env = env;
+ _evaluator = new BinaryClassifierEvaluator(env, args);
+ }
+
+ public BinaryClassificationMetrics Evaluate(IDataView data, string labelColumn = DefaultColumnNames.Label,
+ string probabilityColumn = DefaultColumnNames.Probability)
+ {
+ var ci = EvaluateUtils.GetScoreColumnInfo(_env, data.Schema, null, DefaultColumnNames.Score, MetadataUtils.Const.ScoreColumnKind.BinaryClassification);
+ var map = new KeyValuePair<RoleMappedSchema.ColumnRole, string>[]
+ {
+ RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Probability, probabilityColumn),
+ RoleMappedSchema.CreatePair(MetadataUtils.Const.ScoreValueKind.Score, ci.Name)
+ };
+ var rmd = new RoleMappedData(data, labelColumn, DefaultColumnNames.Features, opt: true, custom: map);
+
+ var metricsDict = _evaluator.Evaluate(rmd);
+ return BinaryClassificationMetrics.FromMetrics(_env, metricsDict["OverallMetrics"], metricsDict["ConfusionMatrix"]).Single();
+ }
+ }
+
+ public static class MyCrossValidation
+ {
+ public sealed class BinaryCrossValidationMetrics
+ {
+ public readonly ITransformer[] FoldModels;
+ public readonly BinaryClassificationMetrics[] FoldMetrics;
+
+ public BinaryCrossValidationMetrics(ITransformer[] models, BinaryClassificationMetrics[] metrics)
+ {
+ FoldModels = models;
+ FoldMetrics = metrics;
+ }
+ }
+
+ public sealed class BinaryCrossValidator
+ {
+ private readonly IHostEnvironment _env;
+
+ public int NumFolds { get; set; } = 2;
+
+ public string StratificationColumn { get; set; }
+
+ public string LabelColumn { get; set; } = DefaultColumnNames.Label;
+
+ public BinaryCrossValidator(IHostEnvironment env)
+ {
+ _env = env;
+ }
+
+ public BinaryCrossValidationMetrics CrossValidate(IDataView trainData, IEstimator<ITransformer> estimator)
+ {
+ var models = new ITransformer[NumFolds];
+ var metrics = new BinaryClassificationMetrics[NumFolds];
+
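+ // When no stratification column is supplied, add a uniform random column; RangeFilter then carves each fold out of its range.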
+ if (StratificationColumn == null)
+ {
+ StratificationColumn = "StratificationColumn";
+ var random = new GenerateNumberTransform(_env, trainData, StratificationColumn);
+ trainData = random;
+ }
+ else
+ throw new NotImplementedException();
+
+ var evaluator = new MyBinaryClassifierEvaluator(_env, new BinaryClassifierEvaluator.Arguments() { });
+
+ for (int fold = 0; fold < NumFolds; fold++)
+ {
+ var trainFilter = new RangeFilter(_env, new RangeFilter.Arguments()
+ {
+ Column = StratificationColumn,
+ Min = (Double)fold / NumFolds,
+ Max = (Double)(fold + 1) / NumFolds,
+ Complement = true
+ }, trainData);
+ var testFilter = new RangeFilter(_env, new RangeFilter.Arguments()
+ {
+ Column = StratificationColumn,
+ Min = (Double)fold / NumFolds,
+ Max = (Double)(fold + 1) / NumFolds,
+ Complement = false
+ }, trainData);
+
+ models[fold] = estimator.Fit(trainFilter);
+ var scoredTest = models[fold].Transform(testFilter);
+ metrics[fold] = evaluator.Evaluate(scoredTest, labelColumn: LabelColumn, probabilityColumn: "Probability");
+ }
+
+ return new BinaryCrossValidationMetrics(models, metrics);
+
+ }
+ }
+ }
+
+ public class MyLambdaTransform<TSrc, TDst> : IEstimator<TransformWrapper>
+ where TSrc : class, new()
+ where TDst : class, new()
+ {
+ private readonly IHostEnvironment _env;
+ private readonly Action<TSrc, TDst> _action;
+
+ public MyLambdaTransform(IHostEnvironment env, Action<TSrc, TDst> action)
+ {
+ _env = env;
+ _action = action;
+ }
+
+ public TransformWrapper Fit(IDataView input)
+ {
+ var xf = LambdaTransform.CreateMap(_env, input, _action);
+ var empty = new EmptyDataView(_env, input.Schema);
+ var chunk = ApplyTransformUtils.ApplyAllTransformsToData(_env, xf, empty, input);
+ return new TransformWrapper(_env, chunk);
+ }
+
+ public SchemaShape GetOutputSchema(SchemaShape inputSchema)
+ {
+ throw new NotImplementedException();
+ }
+ }
+
+ public sealed class MyOva : TrainerBase<ScorerWrapper<OvaPredictor>, OvaPredictor>
+ {
+ private readonly ITrainerEstimator<IPredictorTransformer<TScalarPredictor>, TScalarPredictor> _binaryEstimator;
+
+ public MyOva(IHostEnvironment env, ITrainerEstimator<IPredictorTransformer<TScalarPredictor>, TScalarPredictor> estimator,
+ string featureColumn = DefaultColumnNames.Features, string labelColumn = DefaultColumnNames.Label)
+ : base(env, MakeTrainerInfo(estimator), featureColumn, labelColumn)
+ {
+ _binaryEstimator = estimator;
+ }
+
+ private static TrainerInfo MakeTrainerInfo(ITrainerEstimator<IPredictorTransformer<TScalarPredictor>, TScalarPredictor> estimator)
+ => new TrainerInfo(estimator.TrainerInfo.NeedNormalization, estimator.TrainerInfo.NeedCalibration, false);
+
+ protected override ScorerWrapper<OvaPredictor> MakeScorer(OvaPredictor predictor, RoleMappedData data)
+ => MakeScorerBasic(predictor, data);
+
+ protected override OvaPredictor TrainCore(TrainContext trainContext)
+ {
+ var trainRoles = trainContext.TrainingSet;
+ trainRoles.CheckMultiClassLabel(out var numClasses);
+
+ var predictors = new IPredictorTransformer<TScalarPredictor>[numClasses];
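+ // Train one binary predictor per class against a label-indicator view of the training data.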
+ for (int iClass = 0; iClass < numClasses; iClass++)
+ {
+ var data = new LabelIndicatorTransform(_env, trainRoles.Data, iClass, "Label");
+ predictors[iClass] = _binaryEstimator.Fit(data);
+ }
+ var prs = predictors.Select(x => x.InnerModel);
+ var finalPredictor = OvaPredictor.Create(_env.Register("ova"), prs.ToArray());
+ return finalPredictor;
+ }
+ }
+
+ public static class MyHelperExtensions
+ {
+ public static void SaveAsBinary(this IDataView data, IHostEnvironment env, Stream stream)
+ {
+ var saver = new BinarySaver(env, new BinarySaver.Arguments());
+ using (var ch = env.Start("SaveData"))
+ DataSaverUtils.SaveDataView(ch, saver, data, stream);
+ }
+
+ public static IDataView FitAndTransform(this IEstimator<ITransformer> est, IDataView data) => est.Fit(data).Transform(data);
+
+ public static IDataView FitAndRead<TSource>(this IDataReaderEstimator<TSource, IDataReader<TSource>> est, TSource source)
+ => est.Fit(source).Read(source);
+ }
+}
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/IntrospectiveTraining.cs b/test/Microsoft.ML.Tests/Scenarios/Api/IntrospectiveTraining.cs
index 00e4f9b703..b1be31e89d 100644
--- a/test/Microsoft.ML.Tests/Scenarios/Api/IntrospectiveTraining.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/IntrospectiveTraining.cs
@@ -38,7 +38,7 @@ private TOut GetValue(Dictionary keyValues, string key)
///
[Fact]
- void IntrospectiveTraining()
+ public void IntrospectiveTraining()
{
var dataPath = GetDataPath(SentimentDataPath);
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs
index 811f857046..d8e2fa8079 100644
--- a/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/Metacomponents.cs
@@ -1,10 +1,12 @@
-using Microsoft.ML.Runtime;
-using Microsoft.ML.Runtime.Api;
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.EntryPoints;
using Microsoft.ML.Runtime.FastTree;
using Microsoft.ML.Runtime.Learners;
-using System.Linq;
using Xunit;
namespace Microsoft.ML.Tests.Scenarios.Api
@@ -17,7 +19,7 @@ public partial class ApiScenariosTests
/// If they specify a regression or multi-class classifier ideally that should be a compile error.
///
[Fact]
- void Metacomponents()
+ public void Metacomponents()
{
var dataPath = GetDataPath(IrisDataPath);
using (var env = new TlcEnvironment())
@@ -28,7 +30,7 @@ void Metacomponents()
var trainer = new Ova(env, new Ova.Arguments
{
PredictorType = ComponentFactoryUtils.CreateFromFunction(
- (e) => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments()))
+ e => new FastTreeBinaryClassificationTrainer(e, new FastTreeBinaryClassificationTrainer.Arguments()))
});
IDataView trainData = trainer.Info.WantCaching ? (IDataView)new CacheDataView(env, concat, prefetch: null) : concat;
@@ -36,21 +38,7 @@ void Metacomponents()
// Auto-normalization.
NormalizeTransform.CreateIfNeeded(env, ref trainRoles, trainer);
- var predictor = trainer.Train(new Runtime.TrainContext(trainRoles));
-
- var scoreRoles = new RoleMappedData(concat, label: "Label", feature: "Features");
- IDataScorerTransform scorer = ScoreUtils.GetScorer(predictor, scoreRoles, env, trainRoles.Schema);
-
- var keyToValue = new KeyToValueTransform(env, scorer, "PredictedLabel");
- var model = env.CreatePredictionEngine<IrisData, IrisPrediction>(keyToValue);
-
- var testLoader = new TextLoader(env, MakeIrisTextLoaderArgs(), new MultiFileSource(dataPath));
- var testData = testLoader.AsEnumerable<IrisData>(env, false);
- foreach (var input in testData.Take(20))
- {
- var prediction = model.Predict(input);
- Assert.True(prediction.PredictedLabel == input.Label);
- }
+ var predictor = trainer.Train(new TrainContext(trainRoles));
}
}
}
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/MultithreadedPrediction.cs b/test/Microsoft.ML.Tests/Scenarios/Api/MultithreadedPrediction.cs
index 23fab1c1d6..b561f6b3d2 100644
--- a/test/Microsoft.ML.Tests/Scenarios/Api/MultithreadedPrediction.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/MultithreadedPrediction.cs
@@ -1,4 +1,8 @@
-using Microsoft.ML.Runtime.Api;
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime.Api;
using Microsoft.ML.Runtime.Data;
using Microsoft.ML.Runtime.Learners;
using System.Linq;
diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Visibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Visibility.cs
index 97bb01158e..d0c5e91787 100644
--- a/test/Microsoft.ML.Tests/Scenarios/Api/Visibility.cs
+++ b/test/Microsoft.ML.Tests/Scenarios/Api/Visibility.cs
@@ -1,4 +1,8 @@
-using Microsoft.ML.Runtime.Data;
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+// See the LICENSE file in the project root for more information.
+
+using Microsoft.ML.Runtime.Data;
using Xunit;
namespace Microsoft.ML.Tests.Scenarios.Api