-
Notifications
You must be signed in to change notification settings - Fork 1.9k
API scenarios implementation with Estimators #688
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
e05ca2d
Squashed commit of the following:
8fdeae8
add extensibility, multithread prediction and visibility
80949a0
PR comments
6072b20
Some fixes too
48dbe68
Some code quality improvements (more to come)
7bd3e9b
Code quality
aee4720
Added OVA
fdb2f00
Added introspective training example.
007b0ce
Fixed OVA normalization
1d54d46
Merged from master
b3b9fc1
Merge remote-tracking branch 'upstream/master' into feature/api-est-s…
1b232d2
Merge
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Runtime; | ||
using Microsoft.ML.Runtime.Data; | ||
using System; | ||
using System.Collections.Generic; | ||
using System.Linq; | ||
|
||
namespace Microsoft.ML.Core.Data | ||
{ | ||
/// <summary>
/// Describes what a component 'requires' of an incoming schema and what it 'promises' about the outgoing one.
/// It is looser than a full <see cref="ISchema"/>: it lists only a subset of the columns, and it does not
/// pin down the exact <see cref="ColumnType"/> of vector and key columns.
/// </summary>
public sealed class SchemaShape
{
    /// <summary>
    /// The columns that make up this shape.
    /// </summary>
    public readonly Column[] Columns;

    /// <summary>
    /// The relaxed description of a single column.
    /// </summary>
    public sealed class Column
    {
        /// <summary>
        /// Whether the column holds a scalar, a vector of known size, or a vector of variable size.
        /// </summary>
        public enum VectorKind
        {
            Scalar,
            Vector,
            VariableVector
        }

        public readonly string Name;
        public readonly VectorKind Kind;
        public readonly DataKind ItemKind;
        public readonly bool IsKey;
        public readonly string[] MetadataKinds;

        /// <summary>
        /// Creates a column description. <paramref name="name"/> is checked to be non-empty and
        /// <paramref name="metadataKinds"/> is checked to be non-null before anything is stored.
        /// </summary>
        public Column(string name, VectorKind vecKind, DataKind itemKind, bool isKey, string[] metadataKinds)
        {
            Contracts.CheckNonEmpty(name, nameof(name));
            Contracts.CheckValue(metadataKinds, nameof(metadataKinds));

            MetadataKinds = metadataKinds;
            IsKey = isKey;
            ItemKind = itemKind;
            Kind = vecKind;
            Name = name;
        }
    }

    /// <summary>
    /// Wraps the given (non-null) array of columns. The array is stored as is, not copied.
    /// </summary>
    public SchemaShape(Column[] columns)
    {
        Contracts.CheckValue(columns, nameof(columns));
        Columns = columns;
    }

    /// <summary>
    /// Builds a schema shape from a fully defined schema. Hidden columns are left out.
    /// </summary>
    public static SchemaShape Create(ISchema schema)
    {
        Contracts.CheckValue(schema, nameof(schema));
        var visible = new List<Column>();

        for (int index = 0; index < schema.ColumnCount; index++)
        {
            if (schema.IsHidden(index))
                continue;

            var columnType = schema.GetColumnType(index);

            // Known-size vectors are tested first; any other vector is treated as variable-size.
            Column.VectorKind shapeKind = columnType.IsKnownSizeVector
                ? Column.VectorKind.Vector
                : columnType.IsVector
                    ? Column.VectorKind.VariableVector
                    : Column.VectorKind.Scalar;

            var itemType = columnType.ItemType;
            var rawKind = itemType.RawKind;
            var keyFlag = itemType.IsKey;

            // Only the metadata kind names are retained, not their types.
            var metadataKinds = schema.GetMetadataTypes(index)
                .Select(pair => pair.Key)
                .ToArray();

            var columnName = schema.GetColumnName(index);
            visible.Add(new Column(columnName, shapeKind, rawKind, keyFlag, metadataKinds));
        }

        return new SchemaShape(visible.ToArray());
    }

    /// <summary>
    /// Returns the first column whose name equals <paramref name="name"/>, or <c>null</c> if there is no such column.
    /// </summary>
    public Column FindColumn(string name)
    {
        Contracts.CheckValue(name, nameof(name));

        foreach (var candidate in Columns)
        {
            if (candidate.Name == name)
                return candidate;
        }

        return null;
    }

    // REVIEW: I think we should have an IsCompatible method to check if it's OK to use one schema shape
    // as an input to another schema shape. I started writing, but realized that there's more than one way to check for
    // the 'compatibility': as in, 'CAN be compatible' vs. 'WILL be compatible'.
}
|
||
/// <summary>
/// Exception class for schema validation errors.
/// </summary>
public class SchemaException : Exception
{
    /// <summary>
    /// Creates the exception with the default message.
    /// </summary>
    public SchemaException()
    {
    }

    /// <summary>
    /// Creates the exception with a message describing the schema problem.
    /// </summary>
    /// <param name="message">The text describing why the schema is not valid.</param>
    public SchemaException(string message)
        : base(message)
    {
    }

    /// <summary>
    /// Creates the exception with a message and the exception that caused it.
    /// </summary>
    /// <param name="message">The text describing why the schema is not valid.</param>
    /// <param name="innerException">The underlying cause, or <c>null</c> if there is none.</param>
    public SchemaException(string message, Exception innerException)
        : base(message, innerException)
    {
    }
}
|
||
/// <summary>
/// A 'data reader' turns an input of some kind into an <see cref="IDataView"/>.
/// </summary>
/// <typeparam name="TSource">The kind of input that the reader accepts.</typeparam>
public interface IDataReader<in TSource>
{
    /// <summary>
    /// Gets the schema of the data that this reader outputs.
    /// </summary>
    ISchema GetOutputSchema();

    /// <summary>
    /// Creates the data view over the given input.
    /// Since <see cref="IDataView"/>'s are lazy, this call only validates the schema; nothing is actually read yet.
    /// </summary>
    IDataView Read(TSource input);
}
|
||
/// <summary>
/// An <see cref="IDataReader{TSource}"/> sometimes has to be 'fitted' before it can be used.
/// The data reader estimator is the object that performs this fitting.
/// </summary>
/// <typeparam name="TSource">The kind of input that is used for fitting and that the resulting reader accepts.</typeparam>
/// <typeparam name="TReader">The type of the reader produced by <see cref="Fit"/>.</typeparam>
public interface IDataReaderEstimator<in TSource, out TReader>
    where TReader : IDataReader<TSource>
{
    /// <summary>
    /// Gets the 'promise' of the output schema, which is used for schema propagation.
    /// </summary>
    SchemaShape GetOutputSchema();

    /// <summary>
    /// Trains on the given input and returns the resulting data reader.
    ///
    /// REVIEW: you could consider the transformer to take a different <typeparamref name="TSource"/>, but we don't have such components
    /// yet, so why complicate matters?
    /// </summary>
    TReader Fit(TSource input);
}
|
||
/// <summary>
/// A transformer is a component that transforms data.
/// It also supports 'schema propagation', answering the question: 'what will data with this schema look like after the transformation?'.
/// </summary>
public interface ITransformer
{
    /// <summary>
    /// Takes the data in, applies the transformation, and returns the resulting data.
    /// Since <see cref="IDataView"/>'s are lazy, this call only validates the schema; no transformation is actually performed yet.
    /// </summary>
    IDataView Transform(IDataView input);

    /// <summary>
    /// Schema propagation for transformers.
    /// Given an input schema like <paramref name="inputSchema"/>, returns the schema of the transformed data.
    /// Throws <see cref="SchemaException"/> iff the input schema is not valid for the transformer.
    /// </summary>
    ISchema GetOutputSchema(ISchema inputSchema);
}
|
||
/// <summary>
/// An estimator (in Spark terminology) is an 'untrained transformer': it has to be 'fitted' on data
/// in order to produce a transformer.
/// Like transformers, it supports 'schema propagation', but over <see cref="SchemaShape"/> rather than <see cref="ISchema"/>.
/// </summary>
/// <typeparam name="TTransformer">The type of the transformer produced by <see cref="Fit"/>.</typeparam>
public interface IEstimator<out TTransformer>
    where TTransformer : ITransformer
{
    /// <summary>
    /// Schema propagation for estimators.
    /// Given an input schema shape like <paramref name="inputSchema"/>, returns the output schema shape of the estimator.
    /// Throws <see cref="SchemaException"/> iff the input schema is not valid for the estimator.
    /// </summary>
    SchemaShape GetOutputSchema(SchemaShape inputSchema);

    /// <summary>
    /// Trains on the given data and returns the resulting transformer.
    /// </summary>
    TTransformer Fit(IDataView input);
}
} |
111 changes: 111 additions & 0 deletions
111
src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Core.Data; | ||
using Microsoft.ML.Runtime.Model; | ||
using System.IO; | ||
|
||
namespace Microsoft.ML.Runtime.Data | ||
{ | ||
/// <summary>
/// A data reader that first reads the data with an underlying reader and then applies a chain of transformers to it.
/// It can also save itself to a stream.
/// </summary>
public sealed class CompositeDataReader<TSource, TLastTransformer> : IDataReader<TSource>
    where TLastTransformer : class, ITransformer
{
    /// <summary>
    /// The reader that produces the initial data view.
    /// </summary>
    public readonly IDataReader<TSource> Reader;

    /// <summary>
    /// The (possibly empty) chain of transformers applied to the data after it is read.
    /// </summary>
    public readonly TransformerChain<TLastTransformer> Transformer;

    /// <summary>
    /// Combines <paramref name="reader"/> with <paramref name="transformerChain"/>.
    /// When no chain is given, a new empty <see cref="TransformerChain{TLastTransformer}"/> is used.
    /// </summary>
    public CompositeDataReader(IDataReader<TSource> reader, TransformerChain<TLastTransformer> transformerChain = null)
    {
        Contracts.CheckValue(reader, nameof(reader));
        Contracts.CheckValueOrNull(transformerChain);

        Transformer = transformerChain ?? new TransformerChain<TLastTransformer>();
        Reader = reader;
    }

    /// <summary>
    /// Reads <paramref name="input"/> with <see cref="Reader"/> and passes the result through <see cref="Transformer"/>.
    /// </summary>
    public IDataView Read(TSource input) => Transformer.Transform(Reader.Read(input));

    /// <summary>
    /// Propagates the output schema of <see cref="Reader"/> through <see cref="Transformer"/>.
    /// </summary>
    public ISchema GetOutputSchema() => Transformer.GetOutputSchema(Reader.GetOutputSchema());

    /// <summary>
    /// Appends a new transformer to the end of the chain.
    /// </summary>
    /// <returns>The new composite data reader; this instance's fields are not reassigned.</returns>
    public CompositeDataReader<TSource, TNewLast> AppendTransformer<TNewLast>(TNewLast transformer)
        where TNewLast : class, ITransformer
    {
        Contracts.CheckValue(transformer, nameof(transformer));

        var extendedChain = Transformer.Append(transformer);
        return new CompositeDataReader<TSource, TNewLast>(Reader, extendedChain);
    }

    /// <summary>
    /// Saves the reader and the transformer chain to a stream, as a "model file".
    /// The stream must be both writable and seekable.
    /// </summary>
    public void SaveTo(IHostEnvironment env, Stream outputStream)
    {
        Contracts.CheckValue(env, nameof(env));
        env.CheckValue(outputStream, nameof(outputStream));

        env.Check(outputStream.CanWrite && outputStream.CanSeek, "Need a writable and seekable stream to save");

        // Stacked usings: the repository writer is disposed before the channel, as with nested blocks.
        using (var channel = env.Start("Saving pipeline"))
        using (var repository = RepositoryWriter.CreateNew(outputStream, channel))
        {
            channel.Trace("Saving data reader");
            ModelSaveContext.SaveModel(repository, Reader, "Reader");

            channel.Trace("Saving transformer chain");
            ModelSaveContext.SaveModel(repository, Transformer, TransformerChain.LoaderSignature);

            repository.Commit();
        }
    }
}
|
||
/// <summary>
/// Utility class to facilitate loading a composite data reader from a stream.
/// </summary>
public static class CompositeDataReader
{
    /// <summary>
    /// Load the pipeline from a stream: first the data reader stored under the "Reader" name,
    /// then the transformer chain stored under <see cref="TransformerChain.LoaderSignature"/>.
    /// </summary>
    /// <param name="env">The host environment used for checks, tracing and model loading.</param>
    /// <param name="stream">The stream to load from. It must be both readable and seekable.</param>
    /// <returns>A composite data reader built from the loaded reader and transformer chain.</returns>
    public static CompositeDataReader<IMultiStreamSource, ITransformer> LoadFrom(IHostEnvironment env, Stream stream)
    {
        Contracts.CheckValue(env, nameof(env));
        env.CheckValue(stream, nameof(stream));

        env.Check(stream.CanRead && stream.CanSeek, "Need a readable and seekable stream to load");
        using (var rep = RepositoryReader.Open(stream, env))
        using (var ch = env.Start("Loading pipeline"))
        {
            ch.Trace("Loading data reader");
            ModelLoadContext.LoadModel<IDataReader<IMultiStreamSource>, SignatureLoadModel>(env, out var reader, rep, "Reader");

            // The message used to read "Loader transformer chain"; it now matches the other trace messages.
            ch.Trace("Loading transformer chain");
            ModelLoadContext.LoadModel<TransformerChain<ITransformer>, SignatureLoadModel>(env, out var transformerChain, rep, TransformerChain.LoaderSignature);
            return new CompositeDataReader<IMultiStreamSource, ITransformer>(reader, transformerChain);
        }
    }
}
} |
59 changes: 59 additions & 0 deletions
59
src/Microsoft.ML.Data/DataLoadSave/CompositeReaderEstimator.cs
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
// Licensed to the .NET Foundation under one or more agreements. | ||
// The .NET Foundation licenses this file to you under the MIT license. | ||
// See the LICENSE file in the project root for more information. | ||
|
||
using Microsoft.ML.Core.Data; | ||
|
||
namespace Microsoft.ML.Runtime.Data | ||
{ | ||
/// <summary>
/// The estimator counterpart of the composite data reader.
/// It can be used to build a 'trainable smart data reader', although this pattern is not very common.
/// </summary>
public sealed class CompositeReaderEstimator<TSource, TLastTransformer> : IDataReaderEstimator<TSource, CompositeDataReader<TSource, TLastTransformer>>
    where TLastTransformer : class, ITransformer
{
    private readonly IDataReaderEstimator<TSource, IDataReader<TSource>> _start;
    private readonly EstimatorChain<TLastTransformer> _estimatorChain;

    /// <summary>
    /// Combines the reader estimator <paramref name="start"/> with <paramref name="estimatorChain"/>.
    /// When no chain is given, a new empty <see cref="EstimatorChain{TLastTransformer}"/> is used.
    /// </summary>
    public CompositeReaderEstimator(IDataReaderEstimator<TSource, IDataReader<TSource>> start, EstimatorChain<TLastTransformer> estimatorChain = null)
    {
        Contracts.CheckValue(start, nameof(start));
        Contracts.CheckValueOrNull(estimatorChain);

        _estimatorChain = estimatorChain ?? new EstimatorChain<TLastTransformer>();
        _start = start;

        // REVIEW: enforce that estimator chain can read the reader's schema.
        // Right now it throws.
        // GetOutputSchema();
    }

    /// <summary>
    /// Fits the starting reader on <paramref name="input"/>, reads the same input with it,
    /// fits the estimator chain on the resulting data, and returns the two combined into a composite data reader.
    /// </summary>
    public CompositeDataReader<TSource, TLastTransformer> Fit(TSource input)
    {
        var fittedReader = _start.Fit(input);
        var fittedChain = _estimatorChain.Fit(fittedReader.Read(input));
        return new CompositeDataReader<TSource, TLastTransformer>(fittedReader, fittedChain);
    }

    /// <summary>
    /// Propagates the output schema shape of the starting reader estimator through the estimator chain.
    /// </summary>
    public SchemaShape GetOutputSchema() => _estimatorChain.GetOutputSchema(_start.GetOutputSchema());

    /// <summary>
    /// Appends another estimator to the end of the chain and returns the resulting new composite estimator.
    /// </summary>
    public CompositeReaderEstimator<TSource, TNewTrans> Append<TNewTrans>(IEstimator<TNewTrans> estimator)
        where TNewTrans : class, ITransformer
    {
        Contracts.CheckValue(estimator, nameof(estimator));

        var extendedChain = _estimatorChain.Append(estimator);
        return new CompositeReaderEstimator<TSource, TNewTrans>(_start, extendedChain);
    }
}
|
||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since it's not part of Contracts, who will ever throw it? #Resolved
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We all shall, in time.
In reply to: 211687059 [](ancestors = 211687059)