Skip to content

API scenarios implementation with Estimators #688

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Aug 21, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 190 additions & 0 deletions src/Microsoft.ML.Core/Data/IEstimator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Runtime;
using Microsoft.ML.Runtime.Data;
using System;
using System.Collections.Generic;
using System.Linq;

namespace Microsoft.ML.Core.Data
{
/// <summary>
/// A set of 'requirements' to the incoming schema, as well as a set of 'promises' of the outgoing schema.
/// This is more relaxed than the proper <see cref="ISchema"/>, since it's only a subset of the columns,
/// and also since it doesn't specify exact <see cref="ColumnType"/>'s for vectors and keys.
/// </summary>
public sealed class SchemaShape
{
public readonly Column[] Columns;

public sealed class Column
{
public enum VectorKind
{
Scalar,
Vector,
VariableVector
}

public readonly string Name;
public readonly VectorKind Kind;
public readonly DataKind ItemKind;
public readonly bool IsKey;
public readonly string[] MetadataKinds;

public Column(string name, VectorKind vecKind, DataKind itemKind, bool isKey, string[] metadataKinds)
{
Contracts.CheckNonEmpty(name, nameof(name));
Contracts.CheckValue(metadataKinds, nameof(metadataKinds));

Name = name;
Kind = vecKind;
ItemKind = itemKind;
IsKey = isKey;
MetadataKinds = metadataKinds;
}
}

public SchemaShape(Column[] columns)
{
Contracts.CheckValue(columns, nameof(columns));
Columns = columns;
}

/// <summary>
/// Create a schema shape out of the fully defined schema.
/// </summary>
public static SchemaShape Create(ISchema schema)
{
Contracts.CheckValue(schema, nameof(schema));
var cols = new List<Column>();

for (int iCol = 0; iCol < schema.ColumnCount; iCol++)
{
if (!schema.IsHidden(iCol))
{
Column.VectorKind vecKind;
var type = schema.GetColumnType(iCol);
if (type.IsKnownSizeVector)
vecKind = Column.VectorKind.Vector;
else if (type.IsVector)
vecKind = Column.VectorKind.VariableVector;
else
vecKind = Column.VectorKind.Scalar;

var kind = type.ItemType.RawKind;
var isKey = type.ItemType.IsKey;

var metadataNames = schema.GetMetadataTypes(iCol)
.Select(kvp => kvp.Key)
.ToArray();
cols.Add(new Column(schema.GetColumnName(iCol), vecKind, kind, isKey, metadataNames));
}
}
return new SchemaShape(cols.ToArray());
}

/// <summary>
/// Returns the column with a specified <paramref name="name"/>, and <c>null</c> if there is no such column.
/// </summary>
public Column FindColumn(string name)
{
Contracts.CheckValue(name, nameof(name));
return Columns.FirstOrDefault(x => x.Name == name);
}

// REVIEW: I think we should have an IsCompatible method to check if it's OK to use one schema shape
// as an input to another schema shape. I started writing, but realized that there's more than one way to check for
// the 'compatibility': as in, 'CAN be compatible' vs. 'WILL be compatible'.
}

/// <summary>
/// Exception class for schema validation errors.
/// </summary>
public class SchemaException : Exception
Copy link
Contributor

@Ivanidzo4ka Ivanidzo4ka Aug 21, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SchemaException [](start = 17, length = 15)

Since it's not part of Contracts, who will ever throw it? #Resolved

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We all shall, in time.


In reply to: 211687059 [](ancestors = 211687059)

{
}

/// <summary>
/// The 'data reader' takes a certain kind of input and turns it into an <see cref="IDataView"/>.
/// </summary>
/// <typeparam name="TSource">The type of input the reader takes.</typeparam>
public interface IDataReader<in TSource>
{
/// <summary>
/// Produce the data view from the specified input.
/// Note that <see cref="IDataView"/>'s are lazy, so no actual reading happens here, just schema validation.
/// </summary>
IDataView Read(TSource input);

/// <summary>
/// The output schema of the reader.
/// </summary>
ISchema GetOutputSchema();
}

/// <summary>
/// Sometimes we need to 'fit' an <see cref="IDataReader{TIn}"/>.
/// A DataReader estimator is the object that does it.
/// </summary>
public interface IDataReaderEstimator<in TSource, out TReader>
where TReader : IDataReader<TSource>
{
/// <summary>
/// Train and return a data reader.
///
/// REVIEW: you could consider the transformer to take a different <typeparamref name="TSource"/>, but we don't have such components
/// yet, so why complicate matters?
/// </summary>
TReader Fit(TSource input);

/// <summary>
/// The 'promise' of the output schema.
/// It will be used for schema propagation.
/// </summary>
SchemaShape GetOutputSchema();
}

/// <summary>
/// The transformer is a component that transforms data.
/// It also supports 'schema propagation' to answer the question of 'how the data with this schema look after you transform it?'.
/// </summary>
public interface ITransformer
{
/// <summary>
/// Schema propagation for transformers.
/// Returns the output schema of the data, if the input schema is like the one provided.
/// Throws <see cref="SchemaException"/> iff the input schema is not valid for the transformer.
/// </summary>
ISchema GetOutputSchema(ISchema inputSchema);

/// <summary>
/// Take the data in, make transformations, output the data.
/// Note that <see cref="IDataView"/>'s are lazy, so no actual transformations happen here, just schema validation.
/// </summary>
IDataView Transform(IDataView input);
}

/// <summary>
/// The estimator (in Spark terminology) is an 'untrained transformer'. It needs to 'fit' on the data to manufacture
/// a transformer.
/// It also provides the 'schema propagation' like transformers do, but over <see cref="SchemaShape"/> instead of <see cref="ISchema"/>.
/// </summary>
public interface IEstimator<out TTransformer>
where TTransformer : ITransformer
{
/// <summary>
/// Train and return a transformer.
/// </summary>
TTransformer Fit(IDataView input);

/// <summary>
/// Schema propagation for estimators.
/// Returns the output schema shape of the estimator, if the input schema shape is like the one provided.
/// Throws <see cref="SchemaException"/> iff the input schema is not valid for the estimator.
/// </summary>
SchemaShape GetOutputSchema(SchemaShape inputSchema);
}
}
111 changes: 111 additions & 0 deletions src/Microsoft.ML.Data/DataLoadSave/CompositeDataReader.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.Model;
using System.IO;

namespace Microsoft.ML.Runtime.Data
{
/// <summary>
/// This class represents a data reader that applies a transformer chain after reading.
/// It also has methods to save itself to a repository.
/// </summary>
public sealed class CompositeDataReader<TSource, TLastTransformer> : IDataReader<TSource>
where TLastTransformer : class, ITransformer
{
/// <summary>
/// The underlying data reader.
/// </summary>
public readonly IDataReader<TSource> Reader;
/// <summary>
/// The chain of transformers (possibly empty) that are applied to data upon reading.
/// </summary>
public readonly TransformerChain<TLastTransformer> Transformer;

public CompositeDataReader(IDataReader<TSource> reader, TransformerChain<TLastTransformer> transformerChain = null)
{
Contracts.CheckValue(reader, nameof(reader));
Contracts.CheckValueOrNull(transformerChain);

Reader = reader;
Transformer = transformerChain ?? new TransformerChain<TLastTransformer>();
}

public IDataView Read(TSource input)
{
var idv = Reader.Read(input);
idv = Transformer.Transform(idv);
return idv;
}

public ISchema GetOutputSchema()
{
var s = Reader.GetOutputSchema();
return Transformer.GetOutputSchema(s);
}

/// <summary>
/// Append a new transformer to the end.
/// </summary>
/// <returns>The new composite data reader</returns>
public CompositeDataReader<TSource, TNewLast> AppendTransformer<TNewLast>(TNewLast transformer)
where TNewLast : class, ITransformer
{
Contracts.CheckValue(transformer, nameof(transformer));

return new CompositeDataReader<TSource, TNewLast>(Reader, Transformer.Append(transformer));
}

/// <summary>
/// Save the contents to a stream, as a "model file".
/// </summary>
public void SaveTo(IHostEnvironment env, Stream outputStream)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(outputStream, nameof(outputStream));

env.Check(outputStream.CanWrite && outputStream.CanSeek, "Need a writable and seekable stream to save");
using (var ch = env.Start("Saving pipeline"))
{
using (var rep = RepositoryWriter.CreateNew(outputStream, ch))
{
ch.Trace("Saving data reader");
ModelSaveContext.SaveModel(rep, Reader, "Reader");

ch.Trace("Saving transformer chain");
ModelSaveContext.SaveModel(rep, Transformer, TransformerChain.LoaderSignature);
rep.Commit();
}
}
}
}

/// <summary>
/// Utility class to facilitate loading from a stream.
/// </summary>
public static class CompositeDataReader
{
/// <summary>
/// Load the pipeline from stream.
/// </summary>
public static CompositeDataReader<IMultiStreamSource, ITransformer> LoadFrom(IHostEnvironment env, Stream stream)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(stream, nameof(stream));

env.Check(stream.CanRead && stream.CanSeek, "Need a readable and seekable stream to load");
using (var rep = RepositoryReader.Open(stream, env))
using (var ch = env.Start("Loading pipeline"))
{
ch.Trace("Loading data reader");
ModelLoadContext.LoadModel<IDataReader<IMultiStreamSource>, SignatureLoadModel>(env, out var reader, rep, "Reader");

ch.Trace("Loader transformer chain");
ModelLoadContext.LoadModel<TransformerChain<ITransformer>, SignatureLoadModel>(env, out var transformerChain, rep, TransformerChain.LoaderSignature);
return new CompositeDataReader<IMultiStreamSource, ITransformer>(reader, transformerChain);
}
}
}
}
59 changes: 59 additions & 0 deletions src/Microsoft.ML.Data/DataLoadSave/CompositeReaderEstimator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using Microsoft.ML.Core.Data;

namespace Microsoft.ML.Runtime.Data
{
/// <summary>
/// An estimator class for composite data reader.
/// It can be used to build a 'trainable smart data reader', although this pattern is not very common.
/// </summary>
public sealed class CompositeReaderEstimator<TSource, TLastTransformer> : IDataReaderEstimator<TSource, CompositeDataReader<TSource, TLastTransformer>>
where TLastTransformer : class, ITransformer
{
private readonly IDataReaderEstimator<TSource, IDataReader<TSource>> _start;
private readonly EstimatorChain<TLastTransformer> _estimatorChain;

public CompositeReaderEstimator(IDataReaderEstimator<TSource, IDataReader<TSource>> start, EstimatorChain<TLastTransformer> estimatorChain = null)
{
Contracts.CheckValue(start, nameof(start));
Contracts.CheckValueOrNull(estimatorChain);

_start = start;
_estimatorChain = estimatorChain ?? new EstimatorChain<TLastTransformer>();

// REVIEW: enforce that estimator chain can read the reader's schema.
// Right now it throws.
// GetOutputSchema();
}

public CompositeDataReader<TSource, TLastTransformer> Fit(TSource input)
{
var start = _start.Fit(input);
var idv = start.Read(input);

var xfChain = _estimatorChain.Fit(idv);
return new CompositeDataReader<TSource, TLastTransformer>(start, xfChain);
}

public SchemaShape GetOutputSchema()
{
var shape = _start.GetOutputSchema();
return _estimatorChain.GetOutputSchema(shape);
}

/// <summary>
/// Append another estimator to the end.
/// </summary>
public CompositeReaderEstimator<TSource, TNewTrans> Append<TNewTrans>(IEstimator<TNewTrans> estimator)
where TNewTrans : class, ITransformer
{
Contracts.CheckValue(estimator, nameof(estimator));

return new CompositeReaderEstimator<TSource, TNewTrans>(_start, _estimatorChain.Append(estimator));
}
}

}
Loading